use async_trait::async_trait;
use futures_core::Stream;
use serde::de::{MapAccess, Visitor};
use serde::{Deserialize, Deserializer, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::pin::Pin;
use std::sync::LazyLock;
use crate::context::Context;
use crate::errors::ModuleError;
/// Pinned, boxed stream of output chunks, as returned by `Module::stream`.
///
/// Each item is either a JSON chunk or a `ModuleError`. The `Send` bound
/// allows the stream to be moved across threads.
pub type ChunkStream =
    Pin<Box<dyn Stream<Item = Result<serde_json::Value, ModuleError>> + Send>>;
/// Interface implemented by every module.
///
/// Implementors must supply the schema and description accessors plus
/// [`Module::execute`]. Streaming, self-description, preflight and the
/// lifecycle hooks all have default implementations that may be overridden.
#[async_trait]
pub trait Module: Send + Sync {
    /// Returns the input schema as a JSON value.
    fn input_schema(&self) -> serde_json::Value;

    /// Returns the output schema as a JSON value.
    fn output_schema(&self) -> serde_json::Value;

    /// Returns the module's description text.
    fn description(&self) -> &str;

    /// Executes the module on `inputs` with the given context.
    ///
    /// # Errors
    ///
    /// Returns a [`ModuleError`] when execution fails.
    async fn execute(
        &self,
        inputs: serde_json::Value,
        ctx: &Context<serde_json::Value>,
    ) -> Result<serde_json::Value, ModuleError>;

    /// Optionally returns a stream of output chunks for `inputs`.
    ///
    /// The default implementation returns `None` (no stream).
    fn stream(
        &self,
        _inputs: serde_json::Value,
        _ctx: &Context<serde_json::Value>,
    ) -> Option<ChunkStream> {
        None
    }

    /// Returns a JSON object with `description`, `input_schema` and
    /// `output_schema` keys, filled from the corresponding methods.
    fn describe(&self) -> serde_json::Value {
        // Evaluated in the same order as the keys appear below.
        let description = self.description();
        let input_schema = self.input_schema();
        let output_schema = self.output_schema();
        serde_json::json!({
            "description": description,
            "input_schema": input_schema,
            "output_schema": output_schema,
        })
    }

    /// Returns the module's preflight result.
    ///
    /// The default reports `valid: true`, no checks, and no approval
    /// requirement.
    fn preflight(&self) -> PreflightResult {
        let checks = Vec::new();
        PreflightResult {
            valid: true,
            requires_approval: false,
            checks,
        }
    }

    /// Lifecycle hook (presumably called when the module is loaded); no-op by default.
    fn on_load(&self) {}

    /// Lifecycle hook (presumably called when the module is unloaded); no-op by default.
    fn on_unload(&self) {}

    /// Lifecycle hook returning optional state to keep while suspended.
    ///
    /// The default returns `None`.
    fn on_suspend(&self) -> Option<serde_json::Value> {
        None
    }

    /// Lifecycle hook receiving state on resume; presumably the value produced
    /// by `on_suspend` (confirm against the caller). The default ignores it.
    fn on_resume(&self, _state: serde_json::Value) {}
}
/// Annotations attached to a module: boolean flags, caching and pagination
/// settings, and free-form `extra` entries.
///
/// Defaults come from the `Default` impl. Every flag is `false` except
/// `open_world` (`true`), and `pagination_style` is `"cursor"`.
///
/// Deserialization is hand-written: keys that match no field are collected
/// into `extra`, and an explicit `"extra"` object is merged over them.
/// Serialization is derived and writes `extra` as a nested object.
#[derive(Debug, Clone, Serialize)]
// Each flag is an independent, individually-named key in the serialized form.
// Folding them into enums, as this lint suggests, would change that format.
#[allow(clippy::struct_excessive_bools)]
pub struct ModuleAnnotations {
    pub readonly: bool,
    pub destructive: bool,
    pub idempotent: bool,
    pub requires_approval: bool,
    /// Defaults to `true`, unlike the other flags.
    pub open_world: bool,
    pub streaming: bool,
    pub cacheable: bool,
    /// Defaults to `0`. Units are not specified here (presumably seconds; confirm).
    pub cache_ttl: u64,
    /// `None` by default.
    pub cache_key_fields: Option<Vec<String>>,
    pub paginated: bool,
    /// Defaults to `"cursor"`.
    pub pagination_style: String,
    /// Extra key/value pairs not covered by the named fields.
    pub extra: HashMap<String, serde_json::Value>,
}
impl<'de> Deserialize<'de> for ModuleAnnotations {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct AnnotationsVisitor;
impl<'de> Visitor<'de> for AnnotationsVisitor {
type Value = ModuleAnnotations;
fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("a ModuleAnnotations JSON object")
}
fn visit_map<M>(self, mut map: M) -> Result<ModuleAnnotations, M::Error>
where
M: MapAccess<'de>,
{
let mut ann = ModuleAnnotations::default();
let mut explicit_extra: Option<HashMap<String, serde_json::Value>> = None;
let mut overflow: HashMap<String, serde_json::Value> = HashMap::new();
while let Some(key) = map.next_key::<String>()? {
match key.as_str() {
"readonly" => ann.readonly = map.next_value()?,
"destructive" => ann.destructive = map.next_value()?,
"idempotent" => ann.idempotent = map.next_value()?,
"requires_approval" => ann.requires_approval = map.next_value()?,
"open_world" => ann.open_world = map.next_value()?,
"streaming" => ann.streaming = map.next_value()?,
"cacheable" => ann.cacheable = map.next_value()?,
"cache_ttl" => ann.cache_ttl = map.next_value()?,
"cache_key_fields" => ann.cache_key_fields = map.next_value()?,
"paginated" => ann.paginated = map.next_value()?,
"pagination_style" => ann.pagination_style = map.next_value()?,
"extra" => {
let v: serde_json::Value = map.next_value()?;
explicit_extra = Some(match v {
serde_json::Value::Null => HashMap::new(),
serde_json::Value::Object(obj) => obj.into_iter().collect(),
_ => {
return Err(serde::de::Error::custom(
"ModuleAnnotations.extra must be an object",
))
}
});
}
_ => {
let v: serde_json::Value = map.next_value()?;
overflow.insert(key, v);
}
}
}
let mut merged = overflow;
if let Some(ex) = explicit_extra {
for (k, v) in ex {
merged.insert(k, v);
}
}
ann.extra = merged;
Ok(ann)
}
}
deserializer.deserialize_map(AnnotationsVisitor)
}
}
impl Default for ModuleAnnotations {
fn default() -> Self {
Self {
readonly: false,
destructive: false,
idempotent: false,
requires_approval: false,
open_world: true,
streaming: false,
cacheable: false,
cache_ttl: 0,
cache_key_fields: None,
paginated: false,
pagination_style: "cursor".to_string(),
extra: HashMap::new(),
}
}
}
pub static DEFAULT_ANNOTATIONS: LazyLock<ModuleAnnotations> =
LazyLock::new(ModuleAnnotations::default);
/// A titled example pairing sample `inputs` with an `output` value.
///
/// `description` is optional: it is omitted from serialized output when
/// `None`, and is `None` when missing from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModuleExample {
    pub title: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    pub inputs: serde_json::Value,
    pub output: serde_json::Value,
}
/// Validation outcome: a validity flag plus error and warning messages.
///
/// `valid` is required when deserializing; `errors` and `warnings` default
/// to empty lists if absent.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ValidationResult {
    pub valid: bool,
    /// Empty when absent from the input.
    #[serde(default)]
    pub errors: Vec<String>,
    /// Empty when absent from the input.
    #[serde(default)]
    pub warnings: Vec<String>,
}
/// Outcome of a single named preflight check.
///
/// `error` is omitted from serialized output when `None`; `warnings`
/// defaults to an empty list when absent from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreflightCheckResult {
    /// Name of the check.
    pub check: String,
    /// Whether the check passed.
    pub passed: bool,
    /// Optional error payload; not serialized when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<serde_json::Value>,
    /// Empty when absent from the input.
    #[serde(default)]
    pub warnings: Vec<String>,
}
/// Aggregate outcome of a module's preflight checks.
///
/// `valid` and `requires_approval` are stored as given; this type does not
/// derive them from `checks`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreflightResult {
    /// Overall validity flag.
    pub valid: bool,
    /// Individual check outcomes. Defaults to empty when absent from the
    /// input, matching the other `Vec` fields in this file and the empty list
    /// produced by `Module::preflight`'s default implementation.
    #[serde(default)]
    pub checks: Vec<PreflightCheckResult>,
    /// `false` when absent from the input.
    #[serde(default)]
    pub requires_approval: bool,
}

impl PreflightResult {
    /// Returns references to the checks that did not pass, in their original order.
    ///
    /// A check counts as failed based only on `passed == false`, regardless of
    /// whether its `error` field is set.
    #[must_use]
    pub fn errors(&self) -> Vec<&PreflightCheckResult> {
        self.checks.iter().filter(|check| !check.passed).collect()
    }
}