use std::collections::{BTreeMap, BTreeSet};
use serde::{Deserialize, Serialize};
use crate::events::{EventEnvelope, ThreadId, TurnId};
use crate::extension::ProvidedService;
use crate::inference::{AgentInferenceRequest, InferenceEvent, ModelDescriptor};
use crate::tools::ToolSpec;
mod dispatch;
pub use dispatch::{
ProcessSubagentCancelParams, ProcessSubagentDefinitionsParams,
ProcessSubagentDefinitionsResult, ProcessSubagentDispatchAck, ProcessSubagentDispatchParams,
ProcessSubagentEvent, ProcessSubagentEventNotification, ProcessTaskCancelParams,
ProcessTaskEvent, ProcessTaskEventNotification, ProcessTaskExecuteAck,
ProcessTaskExecuteParams, ProcessTaskSpecParams, ProcessTaskSpecResult,
};
/// Protocol version string the host speaks with process extensions.
///
/// `validate_initialize_echo` rejects an initialize result whose
/// `protocol_version` is not exactly this string.
pub const PROCESS_EXTENSION_PROTOCOL_VERSION: &str = "0.2.0";
// Wire method names, grouped by namespace. The payload types mentioned below
// are paired with each method by name only.
// NOTE(review): confirm the pairings against the transport layer.
/// `extension/initialize` (pairs by name with [`ProcessInitializeParams`] /
/// [`ProcessInitializeResult`]).
pub const METHOD_INITIALIZE: &str = "extension/initialize";
/// `inference/listModels` (pairs by name with [`ProcessListModelsParams`] /
/// [`ProcessListModelsResult`]).
pub const METHOD_LIST_MODELS: &str = "inference/listModels";
/// `inference/streamTurn` (pairs by name with [`ProcessStreamTurnParams`] /
/// [`ProcessStreamTurnAck`]).
pub const METHOD_STREAM_TURN: &str = "inference/streamTurn";
/// `inference/event` (pairs by name with [`ProcessInferenceEventNotification`]).
pub const METHOD_INFERENCE_EVENT: &str = "inference/event";
/// `subagents/definitions` (pairs by name with
/// [`ProcessSubagentDefinitionsParams`] / [`ProcessSubagentDefinitionsResult`]).
pub const METHOD_SUBAGENTS_DEFINITIONS: &str = "subagents/definitions";
/// `subagents/dispatch` (pairs by name with [`ProcessSubagentDispatchParams`] /
/// [`ProcessSubagentDispatchAck`]).
pub const METHOD_SUBAGENTS_DISPATCH: &str = "subagents/dispatch";
/// `subagents/event` (pairs by name with [`ProcessSubagentEventNotification`]).
pub const METHOD_SUBAGENTS_EVENT: &str = "subagents/event";
/// `subagents/cancel` (pairs by name with [`ProcessSubagentCancelParams`]).
pub const METHOD_SUBAGENTS_CANCEL: &str = "subagents/cancel";
/// `tasks/spec` (pairs by name with [`ProcessTaskSpecParams`] /
/// [`ProcessTaskSpecResult`]).
pub const METHOD_TASKS_SPEC: &str = "tasks/spec";
/// `tasks/execute` (pairs by name with [`ProcessTaskExecuteParams`] /
/// [`ProcessTaskExecuteAck`]).
pub const METHOD_TASKS_EXECUTE: &str = "tasks/execute";
/// `tasks/event` (pairs by name with [`ProcessTaskEventNotification`]).
pub const METHOD_TASKS_EVENT: &str = "tasks/event";
/// `tasks/cancel` (pairs by name with [`ProcessTaskCancelParams`]).
pub const METHOD_TASKS_CANCEL: &str = "tasks/cancel";
/// `tools/call` (pairs by name with [`ProcessToolCallParams`] /
/// [`ProcessToolCallResult`]).
pub const METHOD_TOOLS_CALL: &str = "tools/call";
/// `events/handle` (pairs by name with [`ProcessEventsHandleNotification`]).
pub const METHOD_EVENTS_HANDLE: &str = "events/handle";
/// `extension/event`. NOTE(review): presumably carries
/// [`ProcessExtensionOwnedEvent`] — confirm.
pub const METHOD_EXTENSION_EVENT: &str = "extension/event";
/// `extension/shutdown`; no payload type for it is visible in this module.
pub const METHOD_SHUTDOWN: &str = "extension/shutdown";
/// Host-side configuration entry for one process extension.
///
/// Field names are (de)serialized as written: the `snake_case` rename below is
/// a no-op because every field is already snake_case. Omitted optional fields
/// fall back to the defaults noted on each field.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub struct ProcessExtensionConfig {
    /// Identifier of this configured extension.
    pub id: String,
    /// Whether the extension is enabled; `true` when omitted.
    #[serde(default = "default_enabled")]
    pub enabled: bool,
    /// Reference to the extension's manifest.
    /// NOTE(review): whether this is a path or inline TOML is not visible here — confirm.
    pub manifest: String,
    /// Command used to start the extension process (presumably the executable).
    pub command: String,
    /// Arguments for `command`; empty when omitted.
    #[serde(default)]
    pub args: Vec<String>,
    /// Optional working directory; `None` when omitted.
    #[serde(default)]
    pub cwd: Option<String>,
    /// Extra environment variables (presumably set on the child); empty when omitted.
    #[serde(default)]
    pub env: BTreeMap<String, String>,
    /// Startup timeout in milliseconds (per the field name); 10 000 when omitted.
    #[serde(default = "default_startup_timeout_ms")]
    pub startup_timeout_ms: u64,
    /// Event-kind prefix filter. When omitted it defaults to an empty filter,
    /// and `ProcessEventFilter::matches` returns `false` for every kind.
    #[serde(default)]
    pub event_filter: ProcessEventFilter,
}
/// Serde default for `ProcessExtensionConfig::enabled`: a configured
/// extension is enabled unless the config says otherwise.
fn default_enabled() -> bool {
    let enabled_by_default = true;
    enabled_by_default
}
/// Serde default for `ProcessExtensionConfig::startup_timeout_ms`:
/// ten seconds, expressed in milliseconds.
fn default_startup_timeout_ms() -> u64 {
    const MILLIS_PER_SECOND: u64 = 1_000;
    10 * MILLIS_PER_SECOND
}
/// Prefix filter over event kinds.
///
/// Used as `ProcessExtensionConfig::event_filter` and carried in
/// `ProcessInitializeParams::event_filter`. Entries are plain string prefixes
/// (see `matches`). The `Default` value has no entries and matches nothing.
#[derive(Debug, Clone, Default, Serialize, Deserialize, PartialEq, Eq)]
pub struct ProcessEventFilter {
    /// Event-kind prefixes; empty when omitted.
    #[serde(default)]
    pub kinds: Vec<String>,
}
impl ProcessEventFilter {
    /// Reports whether an event of kind `kind` passes this filter.
    ///
    /// Each entry in `kinds` is a literal prefix (no wildcards). The event
    /// passes when at least one entry is a prefix of `kind`. An empty `kinds`
    /// list lets nothing through; an empty-string entry lets everything through.
    pub fn matches(&self, kind: &str) -> bool {
        self.kinds
            .iter()
            .map(String::as_str)
            .any(|wanted_prefix| kind.starts_with(wanted_prefix))
    }
}
/// Parsed process-extension manifest.
///
/// Field names are (de)serialized verbatim because there is no `rename_all`.
/// `validate_manifest` requires a non-blank `id`, a non-empty `provides`, and an
/// `api_version` requirement that the host's supported extension API version satisfies.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ProcessExtensionManifest {
    /// Extension identifier; the initialize result must echo it exactly.
    pub id: String,
    /// Human-readable name (presumably for display; not inspected in this file).
    pub name: String,
    /// Extension version string (not validated in this file).
    pub version: String,
    /// Semver requirement (e.g. `"^1.0"`) on the host's extension API version.
    pub api_version: String,
    /// Optional description; `None` when omitted.
    #[serde(default)]
    pub description: Option<String>,
    /// Declared services. Must be non-empty, and must equal the `services`
    /// echoed in the initialize result.
    pub provides: Vec<ProcessProvidedService>,
    /// Capabilities the extension asks for; empty when omitted.
    /// NOTE(review): the relation to `ProcessInitializeParams::granted_capabilities`
    /// is not visible here.
    #[serde(default)]
    pub required_capabilities: Vec<String>,
    /// Optional launch description (see `crate::packages::PackageExtensionLaunch`).
    #[serde(default)]
    pub launch: Option<crate::packages::PackageExtensionLaunch>,
}
/// A service declared by a process extension.
///
/// Internally tagged: each entry carries a `type` field holding the
/// snake_case variant name (e.g. `type = "tool_provider"`) next to the
/// variant's own fields.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case", tag = "type")]
pub enum ProcessProvidedService {
    /// `type = "inference_engine"`.
    InferenceEngine { id: String },
    /// `type = "event_sink"`.
    EventSink { id: String },
    /// `type = "subagent_dispatcher"`.
    SubagentDispatcher { id: String },
    /// `type = "task_executor"`.
    TaskExecutor { id: String },
    /// `type = "tool_provider"`. `validate_manifest` requires `tools` to be
    /// non-empty, with non-blank names that are unique within the provider
    /// and parameter schemas whose top-level `type` is `"object"`.
    ToolProvider { id: String, tools: Vec<ToolSpec> },
}
impl ProcessProvidedService {
pub fn service_id(&self) -> &str {
match self {
ProcessProvidedService::InferenceEngine { id } => id,
ProcessProvidedService::EventSink { id } => id,
ProcessProvidedService::SubagentDispatcher { id } => id,
ProcessProvidedService::TaskExecutor { id } => id,
ProcessProvidedService::ToolProvider { id, .. } => id,
}
}
}
impl From<&ProcessProvidedService> for ProvidedService {
    /// Maps a manifest service declaration to the host-side [`ProvidedService`]
    /// of the same kind. Only the service id carries over; a tool provider's
    /// `tools` are not copied.
    fn from(service: &ProcessProvidedService) -> Self {
        match service {
            ProcessProvidedService::InferenceEngine { id } => Self::InferenceEngine(id.to_owned()),
            ProcessProvidedService::EventSink { id } => Self::EventSink(id.to_owned()),
            ProcessProvidedService::SubagentDispatcher { id } => {
                Self::SubagentDispatcher(id.to_owned())
            }
            ProcessProvidedService::TaskExecutor { id } => Self::TaskExecutor(id.to_owned()),
            ProcessProvidedService::ToolProvider { id, tools: _ } => {
                Self::ToolProvider(id.to_owned())
            }
        }
    }
}
/// Parameters for `extension/initialize` (paired by name with
/// [`METHOD_INITIALIZE`]). Field names are camelCase on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct ProcessInitializeParams {
    /// Protocol version string (`protocolVersion`).
    pub protocol_version: String,
    /// Extension API version string (`apiVersion`).
    pub api_version: String,
    /// Extension identifier (`extensionId`).
    pub extension_id: String,
    /// Working directory string.
    pub cwd: String,
    /// Capabilities granted to the extension (`grantedCapabilities`).
    pub granted_capabilities: Vec<String>,
    /// Free-form JSON configuration for the extension.
    pub config: serde_json::Value,
    /// Event-kind prefix filter (`eventFilter`).
    pub event_filter: ProcessEventFilter,
}
/// Response to `extension/initialize`. Field names are camelCase on the wire.
///
/// `validate_initialize_echo` requires every field to echo the host's view:
/// the protocol version, the manifest id, the manifest's `provides` list, and
/// the `manifest_checksum` of the configured manifest text.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ProcessInitializeResult {
    /// Must equal [`PROCESS_EXTENSION_PROTOCOL_VERSION`].
    pub protocol_version: String,
    /// Must equal the manifest's `id`.
    pub extension_id: String,
    /// Must equal the manifest's `provides`, compared element by element, in order.
    pub services: Vec<ProcessProvidedService>,
    /// Must equal `manifest_checksum` of the manifest TOML the host loaded.
    pub manifest_checksum: String,
}
/// Parameters for `inference/listModels` (paired by name with
/// [`METHOD_LIST_MODELS`]). Serialized as `{ "engineId": ... }`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct ProcessListModelsParams {
    /// Id of the inference-engine service being queried.
    pub engine_id: String,
}
/// Result of `inference/listModels` (paired by name with [`METHOD_LIST_MODELS`]).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct ProcessListModelsResult {
    /// Model descriptors reported by the engine.
    pub models: Vec<ModelDescriptor>,
}
/// Parameters for `inference/streamTurn` (paired by name with
/// [`METHOD_STREAM_TURN`]). Field names are camelCase on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ProcessStreamTurnParams {
    /// Id of the inference-engine service to run the turn.
    pub engine_id: String,
    /// Stream identifier. NOTE(review): presumably the same value appears in
    /// `ProcessInferenceEventNotification::stream_id` for this turn's events — confirm.
    pub stream_id: String,
    /// Thread the turn belongs to.
    pub thread_id: ThreadId,
    /// Turn being streamed.
    pub turn_id: TurnId,
    /// The inference request payload.
    pub request: AgentInferenceRequest,
}
/// Acknowledgement for `inference/streamTurn` (paired by name with
/// [`METHOD_STREAM_TURN`]).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct ProcessStreamTurnAck {
    /// Stream identifier (`streamId`).
    pub stream_id: String,
}
/// Parameters for `tools/call` (paired by name with [`METHOD_TOOLS_CALL`]).
/// Field names are camelCase on the wire.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ProcessToolCallParams {
    /// Id of the tool-provider service that owns the tool.
    pub provider_id: String,
    /// Name of the tool to invoke.
    pub tool_name: String,
    /// Identifier for this call (`callId`).
    pub call_id: String,
    /// Thread the call belongs to.
    pub thread_id: ThreadId,
    /// Turn the call belongs to.
    pub turn_id: TurnId,
    /// Tool arguments as arbitrary JSON.
    pub arguments: serde_json::Value,
}
/// Result of `tools/call` (paired by name with [`METHOD_TOOLS_CALL`]).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ProcessToolCallResult {
    /// Textual tool output.
    pub content: String,
    /// Error flag (`isError` on the wire).
    pub is_error: bool,
    /// Optional structured data; JSON `null` when omitted.
    #[serde(default)]
    pub data: serde_json::Value,
}
/// Notification payload for `inference/event` (paired by name with
/// [`METHOD_INFERENCE_EVENT`]).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "camelCase")]
pub struct ProcessInferenceEventNotification {
    /// Stream the event belongs to (`streamId`).
    pub stream_id: String,
    /// The inference event itself.
    pub event: InferenceEvent,
}
/// Notification payload for `events/handle` (paired by name with
/// [`METHOD_EVENTS_HANDLE`]), wrapping a single event envelope.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ProcessEventsHandleNotification {
    /// The event envelope being delivered.
    pub envelope: EventEnvelope,
}
/// An event attributed to a specific extension, carrying an arbitrary JSON payload.
/// NOTE(review): presumably the payload of [`METHOD_EXTENSION_EVENT`] — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
pub struct ProcessExtensionOwnedEvent {
    /// Id of the extension the event belongs to.
    pub extension_id: String,
    /// Event kind string (`eventKind`).
    pub event_kind: String,
    /// Version of the payload schema (`schemaVersion`).
    pub schema_version: u32,
    /// Event payload as arbitrary JSON.
    pub payload: serde_json::Value,
}
pub fn validate_initialize_echo(
manifest: &ProcessExtensionManifest,
manifest_toml: &str,
result: &ProcessInitializeResult,
) -> anyhow::Result<()> {
anyhow::ensure!(
result.protocol_version == PROCESS_EXTENSION_PROTOCOL_VERSION,
"process extension {} speaks protocol {:?} but the host requires {:?}",
manifest.id,
result.protocol_version,
PROCESS_EXTENSION_PROTOCOL_VERSION
);
anyhow::ensure!(
result.extension_id == manifest.id,
"process extension echoed id {:?} but the manifest declares {:?}",
result.extension_id,
manifest.id
);
anyhow::ensure!(
result.services == manifest.provides,
"process extension {} echoed services {:?} but the manifest declares {:?}",
manifest.id,
result.services,
manifest.provides
);
let expected = manifest_checksum(manifest_toml);
anyhow::ensure!(
result.manifest_checksum == expected,
"process extension {} echoed manifest checksum {:?} but the configured manifest hashes \
to {:?}; the child is running against a different manifest",
manifest.id,
result.manifest_checksum,
expected
);
Ok(())
}
/// Statically validates a parsed process-extension manifest.
///
/// The checks run in this order:
/// - `id` is not blank (a whitespace-only id counts as blank);
/// - `provides` declares at least one service;
/// - `api_version` parses as a semver requirement that the host's
///   `crate::extension::SUPPORTED_EXTENSION_API_VERSION` satisfies;
/// - every `ToolProvider` service passes `validate_tool_provider`.
///
/// # Errors
/// Returns the first failed check. Also returns an error, naming the host
/// constant, if the host's own supported API version is not valid semver.
pub fn validate_manifest(manifest: &ProcessExtensionManifest) -> anyhow::Result<()> {
    anyhow::ensure!(
        !manifest.id.trim().is_empty(),
        "process extension manifest is missing an id"
    );
    anyhow::ensure!(
        !manifest.provides.is_empty(),
        "process extension {} declares no provided services",
        manifest.id
    );
    let requirement = semver::VersionReq::parse(&manifest.api_version).map_err(|err| {
        anyhow::anyhow!(
            "process extension {} has invalid api_version {:?}: {err}",
            manifest.id,
            manifest.api_version
        )
    })?;
    // A bare `?` here would surface an opaque semver error that gives no hint
    // it came from the host constant rather than the manifest, so add context.
    let host_version = crate::extension::SUPPORTED_EXTENSION_API_VERSION;
    let supported = semver::Version::parse(host_version).map_err(|err| {
        anyhow::anyhow!("host extension API version {host_version:?} is not valid semver: {err}")
    })?;
    anyhow::ensure!(
        requirement.matches(&supported),
        "process extension {} requires extension API {:?} but the host supports {}",
        manifest.id,
        manifest.api_version,
        supported
    );
    for service in &manifest.provides {
        if let ProcessProvidedService::ToolProvider { id, tools } = service {
            validate_tool_provider(&manifest.id, id, tools)?;
        }
    }
    Ok(())
}
/// Validates the tool list of one `ToolProvider` service.
///
/// The provider must declare at least one tool. Every tool needs a non-blank
/// name, a name unique within this provider, and a `parameters` schema whose
/// top-level `type` is `"object"`.
///
/// # Errors
/// Returns the first violation found. Every message names both the extension
/// and the provider, so the offending provider can be identified when an
/// extension declares several.
fn validate_tool_provider(
    extension_id: &str,
    provider_id: &str,
    tools: &[ToolSpec],
) -> anyhow::Result<()> {
    anyhow::ensure!(
        !tools.is_empty(),
        "process extension {extension_id} tool provider {provider_id} declares no tools"
    );
    let mut seen_names = BTreeSet::new();
    for tool in tools {
        anyhow::ensure!(
            !tool.name.trim().is_empty(),
            "process extension {extension_id} tool provider {provider_id} declares a tool with \
             an empty name"
        );
        // `insert` returns false when the name was already present.
        anyhow::ensure!(
            seen_names.insert(tool.name.as_str()),
            "process extension {extension_id} tool provider {provider_id} declares tool {:?} \
             more than once",
            tool.name
        );
        // Only the top-level `type` keyword is checked; the rest of the schema
        // is not validated here.
        let is_object_schema = tool
            .parameters
            .get("type")
            .and_then(serde_json::Value::as_str)
            == Some("object");
        anyhow::ensure!(
            is_object_schema,
            "process extension {extension_id} tool provider {provider_id} tool {:?} parameters \
             must be a JSON schema object (declare `type = \"object\"`)",
            tool.name
        );
    }
    Ok(())
}
/// Computes the 64-bit FNV-1a hash of the manifest's raw UTF-8 bytes and
/// renders it as 16 lowercase, zero-padded hex digits.
///
/// The hash is byte-exact, so any difference in whitespace or line endings
/// changes it. FNV-1a is not cryptographic: it detects accidental manifest
/// mismatches, not tampering.
pub fn manifest_checksum(manifest_toml: &str) -> String {
    const FNV64_OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const FNV64_PRIME: u64 = 0x0000_0100_0000_01b3;
    // FNV-1a: xor each byte into the state first, then multiply (wrapping).
    let hash = manifest_toml.bytes().fold(FNV64_OFFSET_BASIS, |state, byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV64_PRIME)
    });
    format!("{hash:016x}")
}