use async_trait::async_trait;
use futures_core::Stream;
use serde::de::{MapAccess, Visitor};
use serde::{Deserialize, Deserializer, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::pin::Pin;
use std::sync::LazyLock;
use crate::context::Context;
use crate::errors::ModuleError;
/// Boxed, pinned, `Send` stream whose items are either a JSON chunk or a
/// [`ModuleError`].
///
/// The `'static` bound is what `Box<dyn ...>` implies by default; it is spelled
/// out here only to make the ownership requirement on implementors visible.
pub type ChunkStream =
    Pin<Box<dyn Stream<Item = Result<serde_json::Value, ModuleError>> + Send + 'static>>;
/// A pluggable unit of work that consumes and produces `serde_json::Value`s.
///
/// Implementors must provide [`Module::input_schema`], [`Module::output_schema`],
/// [`Module::description`] and [`Module::execute`]. Every other method has a
/// default body that returns `None`, an empty list or `Ok(())`, or does nothing.
/// An implementor therefore only overrides the hooks it needs.
#[async_trait]
pub trait Module: Send + Sync {
    /// Schema of the values accepted as inputs, as a JSON value.
    /// Reported under `"input_schema"` by [`Module::describe`].
    fn input_schema(&self) -> serde_json::Value;

    /// Schema of the values produced as outputs, as a JSON value.
    /// Reported under `"output_schema"` by [`Module::describe`].
    fn output_schema(&self) -> serde_json::Value;

    /// Description of the module, reported under `"description"` by
    /// [`Module::describe`].
    fn description(&self) -> &str;

    /// Runs the module on `inputs`, with `ctx` as the execution context.
    ///
    /// # Errors
    ///
    /// Returns a [`ModuleError`] when the implementation fails to produce an output.
    async fn execute(
        &self,
        inputs: serde_json::Value,
        ctx: &Context<serde_json::Value>,
    ) -> Result<serde_json::Value, ModuleError>;

    /// Optionally returns a [`ChunkStream`] of results for `inputs` instead of a
    /// single value. The default returns `None`.
    fn stream(
        &self,
        _inputs: serde_json::Value,
        _ctx: &Context<serde_json::Value>,
    ) -> Option<ChunkStream> {
        None
    }

    /// Builds a JSON object with the keys `description`, `input_schema` and
    /// `output_schema`, each filled from the method of the same name.
    fn describe(&self) -> serde_json::Value {
        // Evaluate in the same order the object lists its keys.
        let description = self.description();
        let input_schema = self.input_schema();
        let output_schema = self.output_schema();
        serde_json::json!({
            "description": description,
            "input_schema": input_schema,
            "output_schema": output_schema,
        })
    }

    /// Preflight hook returning a list of messages about `inputs`, with an
    /// optional context. The default reports nothing.
    ///
    /// NOTE(review): whether these messages are warnings or blocking errors is
    /// decided by the caller; confirm at the call site.
    fn preflight(
        &self,
        _inputs: &serde_json::Value,
        _ctx: Option<&Context<serde_json::Value>>,
    ) -> Vec<String> {
        vec![]
    }

    /// Free-form tags for this module. The default is an empty list.
    fn tags(&self) -> Vec<String> {
        vec![]
    }

    /// Optional preview of `inputs`, expressed as a [`PreviewResult`].
    /// The default returns `None`.
    fn preview(
        &self,
        _inputs: &serde_json::Value,
        _ctx: Option<&Context<serde_json::Value>>,
    ) -> Option<PreviewResult> {
        None
    }

    /// Load hook. The default always succeeds.
    ///
    /// # Errors
    ///
    /// Overrides may return a [`ModuleError`]; the default never does.
    fn on_load(&self) -> Result<(), ModuleError> {
        Ok(())
    }

    /// Unload hook. The default does nothing.
    fn on_unload(&self) {}

    /// Suspend hook that may return state to keep. The default returns `None`.
    fn on_suspend(&self) -> Option<serde_json::Value> {
        None
    }

    /// Resume hook receiving a state value. This is presumably the value an
    /// earlier [`Module::on_suspend`] returned; confirm with the caller.
    /// The default ignores it.
    fn on_resume(&self, _state: serde_json::Value) {}

    /// Lets an implementor expose a [`StreamingModule`] view of itself.
    /// The default returns `None`.
    fn as_streaming(&self) -> Option<&dyn StreamingModule> {
        None
    }
}
pub trait StreamingModule: Module {
fn stream_typed(
&self,
inputs: serde_json::Value,
context: &crate::context::Context<serde_json::Value>,
) -> ChunkStream;
}
/// Descriptive hints attached to a module.
///
/// Defaults come from the `Default` impl. `open_world` and `discoverable` are
/// `true`, `pagination_style` is `"cursor"`, and every other flag is `false`,
/// with empty or zero caching settings.
///
/// Deserialization is hand-written. Absent keys keep their defaults, and any
/// unrecognised top-level key is collected into [`ModuleAnnotations::extra`].
/// Serialization is derived, so `extra` is emitted as a nested `"extra"` object.
///
/// NOTE(review): the meaning of each flag is defined by the consumers of these
/// annotations. The field names are the only specification visible here.
#[derive(Debug, Clone, Serialize)]
// The derived `Serialize` maps every bool one-to-one onto a JSON boolean.
// Collapsing them into an enum or bitset would change the wire format, so the
// lint is silenced deliberately.
#[allow(clippy::struct_excessive_bools)]
pub struct ModuleAnnotations {
    /// Read-only hint. Default `false`.
    pub readonly: bool,
    /// Destructive-operation hint. Default `false`.
    pub destructive: bool,
    /// Idempotency hint. Default `false`.
    pub idempotent: bool,
    /// Approval-required hint. Default `false`.
    pub requires_approval: bool,
    /// Open-world hint. Default `true`.
    pub open_world: bool,
    /// Streaming hint. Default `false`.
    pub streaming: bool,
    /// Whether results may be cached. Default `false`.
    pub cacheable: bool,
    /// Cache time-to-live. Default `0`.
    /// NOTE(review): units are not specified in this file (seconds?). Confirm.
    pub cache_ttl: u64,
    /// Optional subset of input fields to build a cache key from. Default `None`.
    pub cache_key_fields: Option<Vec<String>>,
    /// Pagination hint. Default `false`.
    pub paginated: bool,
    /// Name of the pagination style. Default `"cursor"`.
    pub pagination_style: String,
    /// Discoverability hint. Default `true`.
    pub discoverable: bool,
    /// Annotations not covered by the named fields.
    pub extra: HashMap<String, serde_json::Value>,
}
impl<'de> Deserialize<'de> for ModuleAnnotations {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct AnnotationsVisitor;
impl<'de> Visitor<'de> for AnnotationsVisitor {
type Value = ModuleAnnotations;
fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str("a ModuleAnnotations JSON object")
}
fn visit_map<M>(self, mut map: M) -> Result<ModuleAnnotations, M::Error>
where
M: MapAccess<'de>,
{
let mut ann = ModuleAnnotations::default();
let mut explicit_extra: Option<HashMap<String, serde_json::Value>> = None;
let mut overflow: HashMap<String, serde_json::Value> = HashMap::new();
while let Some(key) = map.next_key::<String>()? {
match key.as_str() {
"readonly" => ann.readonly = map.next_value()?,
"destructive" => ann.destructive = map.next_value()?,
"idempotent" => ann.idempotent = map.next_value()?,
"requires_approval" => ann.requires_approval = map.next_value()?,
"open_world" => ann.open_world = map.next_value()?,
"streaming" => ann.streaming = map.next_value()?,
"cacheable" => ann.cacheable = map.next_value()?,
"cache_ttl" => ann.cache_ttl = map.next_value()?,
"cache_key_fields" => ann.cache_key_fields = map.next_value()?,
"paginated" => ann.paginated = map.next_value()?,
"pagination_style" => ann.pagination_style = map.next_value()?,
"discoverable" => ann.discoverable = map.next_value()?,
"extra" => {
let v: serde_json::Value = map.next_value()?;
explicit_extra = Some(match v {
serde_json::Value::Null => HashMap::new(),
serde_json::Value::Object(obj) => obj.into_iter().collect(),
_ => {
return Err(serde::de::Error::custom(
"ModuleAnnotations.extra must be an object",
))
}
});
}
_ => {
let v: serde_json::Value = map.next_value()?;
overflow.insert(key, v);
}
}
}
let mut merged = overflow;
if let Some(ex) = explicit_extra {
for (k, v) in ex {
merged.insert(k, v);
}
}
ann.extra = merged;
Ok(ann)
}
}
deserializer.deserialize_map(AnnotationsVisitor)
}
}
impl Default for ModuleAnnotations {
    /// Baseline annotations.
    ///
    /// Every flag is `false` except `open_world` and `discoverable`. Caching is
    /// off, with a zero TTL and no key fields. The pagination style is
    /// `"cursor"` and `extra` is empty.
    fn default() -> Self {
        Self {
            // Boolean hints.
            readonly: false,
            destructive: false,
            idempotent: false,
            requires_approval: false,
            open_world: true,
            streaming: false,
            // Caching settings.
            cacheable: false,
            cache_ttl: 0,
            cache_key_fields: None,
            // Pagination settings.
            paginated: false,
            pagination_style: String::from("cursor"),
            // Discovery and free-form extensions.
            discoverable: true,
            extra: HashMap::default(),
        }
    }
}
pub static DEFAULT_ANNOTATIONS: LazyLock<ModuleAnnotations> =
LazyLock::new(ModuleAnnotations::default);
/// Example for a module: a title, an optional description, and a sample
/// `inputs` value paired with an `output` value.
///
/// NOTE(review): `output` is presumably the result those inputs produce.
/// Confirm with whoever authors the examples.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[non_exhaustive]
pub struct ModuleExample {
    /// Title of the example.
    pub title: String,
    /// Optional explanation. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
    /// Sample input value.
    pub inputs: serde_json::Value,
    /// Output value paired with `inputs`.
    pub output: serde_json::Value,
}
/// Validation verdict: a `valid` flag plus error and warning messages.
///
/// Both message lists default to empty when absent during deserialization.
/// `valid` is required.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[non_exhaustive]
pub struct ValidationResult {
    /// Overall verdict, stored as given.
    pub valid: bool,
    /// Error messages.
    #[serde(default)]
    pub errors: Vec<String>,
    /// Warning messages.
    #[serde(default)]
    pub warnings: Vec<String>,
}
/// Outcome of one named check inside a [`PreflightResult`].
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[non_exhaustive]
pub struct PreflightCheckResult {
    /// Name of the check.
    pub check: String,
    /// Whether the check passed. `PreflightResult::errors` selects the checks
    /// where this is `false`.
    pub passed: bool,
    /// Optional error payload. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<serde_json::Value>,
    /// Warning messages. Default empty when absent during deserialization.
    #[serde(default)]
    pub warnings: Vec<String>,
}
/// Collection of preflight check outcomes plus approval and predicted-change data.
///
/// `valid` is stored as given; this type does not derive it from `checks`.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[non_exhaustive]
pub struct PreflightResult {
    /// Overall verdict, stored as given. Required when deserializing.
    pub valid: bool,
    /// Individual check outcomes. Required when deserializing.
    pub checks: Vec<PreflightCheckResult>,
    /// Approval flag. Defaults to `false` when absent.
    #[serde(default)]
    pub requires_approval: bool,
    /// Predicted changes. Default empty, and omitted from serialized output
    /// when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub predicted_changes: Vec<Change>,
}
/// A single reported change, used by [`PreviewResult`] and
/// [`PreflightResult::predicted_changes`].
///
/// `action`, `target` and `summary` are required by the hand-written
/// `Deserialize` impl. `before` and `after` are optional JSON values.
///
/// On serialization, `extra` entries are flattened into the top-level object.
/// On deserialization, only extra keys starting with `x-` are accepted, so an
/// `extra` key without that prefix will not round-trip.
#[derive(Clone, Debug, Default, Serialize)]
#[non_exhaustive]
pub struct Change {
    /// Kind of change.
    pub action: String,
    /// What the change applies to.
    pub target: String,
    /// Human-readable summary.
    pub summary: String,
    /// Optional value before the change. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub before: Option<serde_json::Value>,
    /// Optional value after the change. Omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub after: Option<serde_json::Value>,
    /// Extension entries, flattened into the parent object when serialized.
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
impl<'de> Deserialize<'de> for Change {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct ChangeVisitor;
impl<'de> Visitor<'de> for ChangeVisitor {
type Value = Change;
fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(
"a Change JSON object with action/target/summary and optional x-* extras",
)
}
fn visit_map<M>(self, mut map: M) -> Result<Change, M::Error>
where
M: MapAccess<'de>,
{
let mut action: Option<String> = None;
let mut target: Option<String> = None;
let mut summary: Option<String> = None;
let mut before: Option<serde_json::Value> = None;
let mut after: Option<serde_json::Value> = None;
let mut extra: HashMap<String, serde_json::Value> = HashMap::new();
while let Some(key) = map.next_key::<String>()? {
match key.as_str() {
"action" => action = Some(map.next_value()?),
"target" => target = Some(map.next_value()?),
"summary" => summary = Some(map.next_value()?),
"before" => before = Some(map.next_value()?),
"after" => after = Some(map.next_value()?),
other => {
if !other.starts_with("x-") {
return Err(serde::de::Error::custom(format!(
"Change has unknown key '{other}'; extension keys must start with 'x-'"
)));
}
let v: serde_json::Value = map.next_value()?;
extra.insert(other.to_string(), v);
}
}
}
Ok(Change {
action: action.ok_or_else(|| serde::de::Error::missing_field("action"))?,
target: target.ok_or_else(|| serde::de::Error::missing_field("target"))?,
summary: summary.ok_or_else(|| serde::de::Error::missing_field("summary"))?,
before,
after,
extra,
})
}
}
deserializer.deserialize_map(ChangeVisitor)
}
}
/// Value returned by [`Module::preview`]: the list of [`Change`]s it reports.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
#[non_exhaustive]
pub struct PreviewResult {
    /// Reported changes. Default empty when absent during deserialization.
    #[serde(default)]
    pub changes: Vec<Change>,
}
/// Check name `"module_preview"`.
///
/// NOTE(review): presumably used as the `check` value of the
/// [`PreflightCheckResult`] built from a module's preview. Confirm at the use sites.
pub const MODULE_PREVIEW_CHECK_NAME: &str = "module_preview";
impl PreflightResult {
    /// Lazily yields the checks whose `passed` flag is `false`, in their original order.
    fn failed_checks(&self) -> impl Iterator<Item = &PreflightCheckResult> {
        self.checks.iter().filter(|check| !check.passed)
    }

    /// Returns references to every failed check, in their original order.
    #[must_use]
    pub fn errors(&self) -> Vec<&PreflightCheckResult> {
        self.failed_checks().collect()
    }

    /// Returns every failed check serialized as a JSON value, in their original order.
    ///
    /// A check that fails to serialize is skipped rather than reported.
    #[must_use]
    pub fn errors_as_json(&self) -> Vec<serde_json::Value> {
        self.failed_checks()
            .filter_map(|check| serde_json::to_value(check).ok())
            .collect()
    }
}