pub mod pareto_router;
pub mod pdf_input;
pub mod response_healing;
pub mod web_search;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use async_trait::async_trait;
use serde_json::{Map, Value};
use crate::canonical::{ChatRequest, ChatResponse};
use crate::config::Config;
/// Per-invocation context handed to every [`Plugin`] hook.
///
/// NOTE(review): `settings` is presumably the merged map produced by
/// `PluginRegistry::resolve` for this plugin — confirm at the call site.
pub struct PluginContext {
    /// HTTP client available to plugins.
    // NOTE(review): the lint is suppressed presumably because no plugin reads
    // this field yet; drop the allow once one does.
    #[allow(dead_code)]
    pub client: reqwest::Client,
    /// Plugin-specific settings as a JSON object.
    pub settings: Map<String, Value>,
}
impl PluginContext {
    /// Looks up `key` in [`settings`](Self::settings) and returns it as a string slice.
    ///
    /// Returns `None` when the key is missing or when its value is not a JSON
    /// string. Numbers, booleans, etc. are not coerced.
    ///
    /// The returned slice borrows from `self`. Lifetime elision ties it to `&self`,
    /// so no explicit lifetime parameter is needed (`clippy::needless_lifetimes`).
    pub fn get_str(&self, key: &str) -> Option<&str> {
        self.settings.get(key).and_then(Value::as_str)
    }
}
/// A named point in the plugin lifecycle.
///
/// NOTE(review): the variants presumably correspond to the `Plugin` hooks
/// (`on_start`, `pre_request`, `post_response`, `on_end`). Nothing in this
/// file maps one to the other, so confirm against the dispatcher.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Stage {
    /// Lifecycle start.
    Start,
    /// Before the request is routed.
    PreRouting,
    /// After a response is available.
    PostResponse,
    /// Lifecycle end.
    End,
}
/// Outcome reported by a [`Plugin`] hook.
///
/// Every default hook implementation returns [`Flow::Continue`].
/// NOTE(review): how the dispatcher reacts to `Modified` / `Stop` is not
/// visible in this file — confirm before relying on specific semantics.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Flow {
    /// The hook made no change worth reporting.
    Continue,
    /// The hook changed the request and/or response.
    Modified,
    /// Presumably asks the dispatcher to halt further processing.
    // NOTE(review): dead_code is allowed presumably because no plugin
    // constructs this variant yet; remove the allow once one does.
    #[allow(dead_code)]
    Stop,
}
/// An extension that can inspect and rewrite a chat exchange.
///
/// Only [`id`](Plugin::id) is required. Each hook has a default body that
/// does nothing and returns `Ok(Flow::Continue)`, so implementors override
/// just the hooks they care about.
///
/// The first two hooks receive the request mutably. The last two see it
/// read-only. All four may mutate the (optional) response.
/// NOTE(review): the caller that drives these hooks is not in this file; the
/// ordering implied by their names is presumed, not verified here.
#[async_trait]
pub trait Plugin: Send + Sync {
    /// Stable identifier.
    ///
    /// It is the key used to look the plugin up in the configuration's
    /// `plugins` map and to match it against the request's plugin entries.
    fn id(&self) -> &'static str;

    /// Start hook. May rewrite the request and set/alter the response.
    async fn on_start(
        &self,
        _ctx: &PluginContext,
        _req: &mut ChatRequest,
        _resp: &mut Option<ChatResponse>,
    ) -> anyhow::Result<Flow> {
        Ok(Flow::Continue)
    }

    /// Pre-request hook. May rewrite the request and set/alter the response.
    async fn pre_request(
        &self,
        _ctx: &PluginContext,
        _req: &mut ChatRequest,
        _resp: &mut Option<ChatResponse>,
    ) -> anyhow::Result<Flow> {
        Ok(Flow::Continue)
    }

    /// Post-response hook. The request is read-only; the response may be altered.
    async fn post_response(
        &self,
        _ctx: &PluginContext,
        _req: &ChatRequest,
        _resp: &mut Option<ChatResponse>,
    ) -> anyhow::Result<Flow> {
        Ok(Flow::Continue)
    }

    /// End hook. The request is read-only; the response may be altered.
    async fn on_end(
        &self,
        _ctx: &PluginContext,
        _req: &ChatRequest,
        _resp: &mut Option<ChatResponse>,
    ) -> anyhow::Result<Flow> {
        Ok(Flow::Continue)
    }
}
/// One registered plugin together with its configuration-derived defaults.
struct PluginEntry {
    /// The shared plugin instance.
    plugin: Arc<dyn Plugin>,
    /// Whether the configuration switches this plugin on for every request.
    /// `false` when the configuration has no entry for the plugin.
    enabled_by_default: bool,
    /// Settings from the configuration entry. Empty when there is no entry.
    default_settings: Map<String, Value>,
}
/// All known plugins, in registration order, with an id → position index.
pub struct PluginRegistry {
    /// Registered plugins in the order they were added.
    entries: Vec<PluginEntry>,
    /// Maps each plugin's [`Plugin::id`] to its index in `entries`.
    by_id: HashMap<&'static str, usize>,
}
impl PluginRegistry {
pub fn from_config(config: &Config) -> Self {
let plugins: Vec<Arc<dyn Plugin>> = vec![
Arc::new(response_healing::ResponseHealingPlugin),
Arc::new(pareto_router::ParetoRouterPlugin),
Arc::new(web_search::WebSearchPlugin),
Arc::new(pdf_input::PdfInputPlugin),
];
let mut entries = Vec::with_capacity(plugins.len());
let mut by_id = HashMap::with_capacity(plugins.len());
for plugin in plugins {
let id = plugin.id();
let cfg = config.plugins.get(id);
by_id.insert(id, entries.len());
entries.push(PluginEntry {
plugin,
enabled_by_default: cfg.is_some_and(|c| c.enabled),
default_settings: cfg.map(|c| c.settings.clone()).unwrap_or_default(),
});
}
PluginRegistry { entries, by_id }
}
pub fn resolve(&self, req: &ChatRequest) -> Vec<(Arc<dyn Plugin>, Map<String, Value>)> {
let mut result = Vec::new();
let mut included = HashSet::new();
for entry in &self.entries {
if entry.enabled_by_default {
let mut settings = entry.default_settings.clone();
if let Some(req_entry) = req.plugins.iter().find(|p| p.id == entry.plugin.id()) {
settings.extend(req_entry.settings.clone());
}
included.insert(entry.plugin.id());
result.push((entry.plugin.clone(), settings));
}
}
for req_entry in &req.plugins {
if included.contains(req_entry.id.as_str()) {
continue;
}
let Some(&idx) = self.by_id.get(req_entry.id.as_str()) else {
tracing::warn!("ignoring unknown plugin id '{}'", req_entry.id);
continue;
};
let entry = &self.entries[idx];
let mut settings = entry.default_settings.clone();
settings.extend(req_entry.settings.clone());
included.insert(entry.plugin.id());
result.push((entry.plugin.clone(), settings));
}
result
}
}