mod macros;
mod manager;
mod network;
use crate::network::NETWORK;
pub use async_trait::async_trait;
use libloading::Symbol;
pub use manager::PluginManager;
pub use aiway_protocol as protocol;
use protocol::gateway::HttpContext;
pub use semver::Version;
use serde::{Deserialize, Serialize};
pub use serde_json;
use serde_json::Value;
use std::env::temp_dir;
use std::fs;
use std::fs::File;
use std::io::Write;
use std::path::PathBuf;
/// Errors produced while loading or running a plugin.
///
/// Every variant carries a human-readable message, which is exactly what the
/// [`Display`](std::fmt::Display) implementation prints.
#[derive(Debug)]
pub enum PluginError {
    /// A plugin failed while executing (presumably returned from [`Plugin::execute`]).
    ExecuteError(String),
    /// A requested plugin does not exist (presumably a lookup miss in the plugin
    /// manager — confirm in `manager`).
    NotFound(String),
    /// A plugin library could not be downloaded, written to disk, or loaded.
    LoadError(String),
}

impl std::fmt::Display for PluginError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // All variants share the same shape: print the carried message verbatim.
        match self {
            PluginError::ExecuteError(msg)
            | PluginError::NotFound(msg)
            | PluginError::LoadError(msg) => f.write_str(msg),
        }
    }
}

/// Makes `PluginError` usable as a standard error (e.g. `Box<dyn Error>`, `?` into
/// generic error types).
impl std::error::Error for PluginError {}
/// A gateway plugin: a named unit of logic executed against an HTTP request context.
///
/// Plugins can be loaded from dynamic libraries (via `TryFrom<PathBuf>`,
/// `TryFrom<Vec<u8>>` or [`NetworkPlugin`]) and registered with a [`PluginManager`].
///
/// NOTE(review): loaded plugins reach the host as `*mut dyn Plugin` from another
/// library, so host and plugin must agree on this trait's exact definition (method
/// set and order). Changing it presumably requires rebuilding every plugin.
#[async_trait]
pub trait Plugin: Send + Sync {
    /// Name of the plugin. Presumably the key a [`PluginManager`] uses to find it
    /// when running by name — confirm in `manager`.
    fn name(&self) -> &str;
    /// Metadata describing the plugin (version, default configuration, description).
    fn info(&self) -> PluginInfo;
    /// Runs the plugin for a single request.
    ///
    /// `context` is the request's [`HttpContext`]; `config` is the JSON configuration
    /// for this invocation. On success the plugin returns a JSON value.
    ///
    /// # Errors
    /// Returns a [`PluginError`] when the plugin fails; the variant is chosen by the
    /// implementation.
    async fn execute(&self, context: &HttpContext, config: &Value) -> Result<Value, PluginError>;
}
/// Descriptive metadata about a plugin, as returned by [`Plugin::info`].
///
/// Derives serde's `Serialize`/`Deserialize`, so it can be exchanged as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PluginInfo {
    /// Semantic version of the plugin.
    pub version: Version,
    /// Default JSON configuration for the plugin. Presumably used when no explicit
    /// config is supplied to [`Plugin::execute`] — confirm with callers.
    pub default_config: Value,
    /// Human-readable description of the plugin.
    pub description: String,
}
impl TryFrom<PathBuf> for Box<dyn Plugin> {
    type Error = PluginError;

    /// Loads a plugin from the dynamic library at `value`.
    ///
    /// The library must export `create_plugin`, returning an owned `*mut dyn Plugin`
    /// (the host takes it with `Box::from_raw`). It should also export
    /// `destroy_plugin`, which receives that pointer back when the returned plugin is
    /// dropped. The library stays loaded for as long as the returned plugin exists.
    ///
    /// NOTE(review): the trait object crosses the library boundary as a plain Rust
    /// `dyn Plugin`, so host and plugin presumably must be built with the same
    /// compiler and the same version of this crate — confirm.
    ///
    /// # Errors
    /// Returns [`PluginError::LoadError`] if the library cannot be opened, does not
    /// export `create_plugin`, or `create_plugin` returns a null pointer.
    fn try_from(value: PathBuf) -> Result<Self, Self::Error> {
        // SAFETY: opening the library runs its initialisers, and the exported
        // `create_plugin` is trusted to have the declared signature and to return a
        // pointer produced by `Box::into_raw` (or null).
        unsafe {
            let lib = libloading::Library::new(&value)
                .map_err(|e| PluginError::LoadError(e.to_string()))?;
            let create_plugin: Symbol<unsafe extern "C" fn() -> *mut dyn Plugin> = lib
                .get(b"create_plugin")
                .map_err(|e| PluginError::LoadError(e.to_string()))?;
            let plugin_ptr = create_plugin();
            if plugin_ptr.is_null() {
                return Err(PluginError::LoadError(
                    "Failed to create plugin: ptr is null".to_string(),
                ));
            }
            let plugin = Box::from_raw(plugin_ptr);
            Ok(Box::new(LibraryPluginWrapper {
                plugin: std::mem::ManuallyDrop::new(plugin),
                _lib: lib,
            }))
        }
    }
}

/// A plugin created by a dynamic library, bundled with the handle of that library.
///
/// The library must outlive the plugin because the plugin's code and vtable live in
/// it. `plugin` is wrapped in `ManuallyDrop` so that [`Drop`] can hand ownership back
/// to the library's `destroy_plugin` export without the field being freed a second
/// time by the automatic field drop.
struct LibraryPluginWrapper {
    plugin: std::mem::ManuallyDrop<Box<dyn Plugin>>,
    _lib: libloading::Library,
}

/// Forwards every call to the wrapped plugin.
#[async_trait]
impl Plugin for LibraryPluginWrapper {
    fn name(&self) -> &str {
        self.plugin.name()
    }
    fn info(&self) -> PluginInfo {
        self.plugin.info()
    }
    async fn execute(&self, context: &HttpContext, config: &Value) -> Result<Value, PluginError> {
        self.plugin.execute(context, config).await
    }
}

impl Drop for LibraryPluginWrapper {
    /// Releases the plugin while its library is still loaded.
    ///
    /// If the library exports `destroy_plugin`, ownership of the plugin allocation is
    /// transferred to it (NOTE(review): it is expected to free the object, mirroring
    /// `create_plugin` — confirm against the `macros` module). Otherwise the plugin is
    /// dropped here. Either way this runs before the `_lib` field is dropped, because
    /// field drops happen after this method returns.
    fn drop(&mut self) {
        // SAFETY: `self.plugin` is taken exactly once, here, and never used again;
        // `ManuallyDrop` prevents the compiler-generated field drop from freeing it
        // a second time.
        let plugin = unsafe { std::mem::ManuallyDrop::take(&mut self.plugin) };
        // SAFETY: `destroy_plugin`, when exported, is trusted to have this signature.
        let destructor = unsafe {
            self._lib
                .get::<unsafe extern "C" fn(*mut dyn Plugin)>(b"destroy_plugin")
        };
        match destructor {
            // SAFETY: `Box::into_raw` yields the same owned pointer `create_plugin`
            // handed over; from here on the library owns it.
            Ok(destroy) => unsafe { destroy(Box::into_raw(plugin)) },
            // No destructor export: free the plugin locally instead of panicking
            // inside `drop` (which could abort during unwinding).
            Err(_) => drop(plugin),
        }
    }
}
/// A plugin library to be downloaded from a URL.
///
/// The wrapped string is the URL that is fetched with HTTP GET when the value is
/// converted with [`AsyncTryInto::async_try_into`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NetworkPlugin(pub String);
/// Asynchronous counterpart of [`TryInto`]: a fallible, by-value conversion into `T`
/// that may need to await (for example on network I/O) before it completes.
#[async_trait]
pub trait AsyncTryInto<T>
where
    Self: Sized,
{
    /// Error returned when the conversion fails.
    type Error;

    /// Consumes `self` and attempts to produce a `T`.
    async fn async_try_into(self) -> Result<T, Self::Error>;
}
#[async_trait]
impl AsyncTryInto<Box<dyn Plugin>> for NetworkPlugin {
    type Error = PluginError;

    /// Downloads the dynamic library at the URL in `self.0` and loads it as a plugin.
    ///
    /// The response body is written to a uniquely named file in the system temp
    /// directory (with the platform's dynamic-library suffix), loaded via
    /// `TryFrom<PathBuf>`, and the file is then removed on a best-effort basis.
    ///
    /// # Errors
    /// Returns [`PluginError::LoadError`] if the request fails, the server answers
    /// with an error status, the body cannot be read, the temp file cannot be
    /// written, or the library cannot be loaded.
    async fn async_try_into(self) -> Result<Box<dyn Plugin>, Self::Error> {
        let bytes = NETWORK
            .client
            .get(&self.0)
            .send()
            .await
            .map_err(|e| PluginError::LoadError(e.to_string()))?
            .error_for_status()
            .map_err(|e| PluginError::LoadError(e.to_string()))?
            .bytes()
            .await
            .map_err(|e| PluginError::LoadError(e.to_string()))?;

        // An explicit extension matters on Windows, where LoadLibrary appends ".dll"
        // to extension-less paths.
        let tpf = temp_dir().join(format!(
            "{}{}",
            uuid::Uuid::new_v4(),
            std::env::consts::DLL_SUFFIX
        ));
        let plugin = fs::write(&tpf, &bytes)
            .map_err(|e| PluginError::LoadError(e.to_string()))
            .and_then(|()| <Box<dyn Plugin>>::try_from(tpf.clone()));

        // Always attempt cleanup, even if writing or loading failed, so no temp file
        // is leaked. Failure is ignored on purpose: on Unix the loaded image stays
        // mapped after unlinking, while on Windows a loaded DLL cannot be deleted, and
        // that must not throw away a successfully loaded plugin.
        let _ = fs::remove_file(&tpf);
        plugin
    }
}
impl TryFrom<Vec<u8>> for Box<dyn Plugin> {
    type Error = PluginError;

    /// Loads a plugin from the raw bytes of a compiled dynamic library.
    ///
    /// The bytes are written to a uniquely named file in the system temp directory
    /// (with the platform's dynamic-library suffix) and loaded via `TryFrom<PathBuf>`.
    /// The temp file is then removed on a best-effort basis.
    ///
    /// # Errors
    /// Returns [`PluginError::LoadError`] if the temp file cannot be written or the
    /// library cannot be loaded.
    fn try_from(from: Vec<u8>) -> Result<Box<dyn Plugin>, Self::Error> {
        let temp = temp_dir().join(format!(
            "{}{}",
            uuid::Uuid::new_v4(),
            std::env::consts::DLL_SUFFIX
        ));
        let result = fs::write(&temp, from)
            .map_err(|e| PluginError::LoadError(e.to_string()))
            .and_then(|()| <Box<dyn Plugin>>::try_from(temp.clone()));

        // Previously the file was left behind on every load. Removal is best effort:
        // on Unix the loaded image stays mapped after unlinking; on Windows deleting
        // a loaded DLL fails and the file simply remains, as before.
        let _ = fs::remove_file(&temp);
        result
    }
}
#[cfg(test)]
mod tests {
    // These tests load real plugin libraries. They need a plugin server on the local
    // network and a locally built release artifact, so they are ignored by default;
    // run them with `cargo test -- --ignored` where those resources exist.
    use super::*;
    use crate::manager::PluginManager;

    /// Location of the demo plugin served for the network-based tests.
    const DEMO_PLUGIN_URL: &str =
        "http://192.168.1.242:10000/aiway/test/plugins/libdemo_plugin.so";

    #[tokio::test]
    #[ignore = "requires the plugin server at 192.168.1.242:10000"]
    async fn test_network_plugin() {
        let p = NetworkPlugin(DEMO_PLUGIN_URL.to_string());
        let plugin: Box<dyn Plugin> = p.async_try_into().await.unwrap();
        plugin
            .execute(&HttpContext::default(), &Value::Null)
            .await
            .unwrap();
    }

    #[tokio::test]
    #[ignore = "requires the plugin server at 192.168.1.242:10000"]
    async fn test_plugin_manager() {
        let p = NetworkPlugin(DEMO_PLUGIN_URL.to_string());
        let plugin: Box<dyn Plugin> = p.async_try_into().await.unwrap();
        let mut manager = PluginManager::new();
        manager.register(plugin);
        manager
            .run("demo", &HttpContext::default(), &Value::Null)
            .await
            .unwrap();
    }

    #[tokio::test]
    #[ignore = "requires a release build of libaha_model_request_wrapper_plugin.so"]
    async fn test_plugin_from_bytes() {
        // `fs::read` sizes its buffer once from metadata; `File::bytes()` on an
        // unbuffered file performed one syscall per byte of the library.
        let bytes =
            std::fs::read("../../target/release/libaha_model_request_wrapper_plugin.so").unwrap();
        let plugin: Box<dyn Plugin> = bytes.try_into().unwrap();
        println!("{:?}", plugin.info());
    }
}