#![cfg_attr(test, allow(clippy::expect_used, clippy::unwrap_used))]
use std::time::Duration;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
use tokio::process::Command;
use tracing::warn;
use crate::extensions::{
registry::RegisteredExtension,
wire::{Action, Event, EventName},
ExtensionRegistry,
};
/// Result of dispatching a cancellable hook that carries no payload to rewrite
/// (returned by [`dispatch_session_before_switch`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookOutcome {
    /// No subscribed extension cancelled; the caller may go ahead.
    Continue,
    /// An extension answered with a cancel action.
    Cancelled {
        /// Name of the extension that cancelled.
        extension_name: String,
        /// Reason supplied by the extension, if it gave one.
        reason: Option<String>,
    },
}
/// Result of running the `before_user_message` hook chain
/// (returned by [`dispatch_before_user_message`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BeforeUserMessageOutcome {
    /// No extension cancelled the message.
    Proceed {
        /// Message text after any text transformations returned by extensions.
        text: String,
        /// Attachments exactly as they were passed in to the dispatcher.
        attachments: Vec<crate::user_message::Attachment>,
    },
    /// An extension answered with a cancel action.
    Cancelled {
        /// Name of the extension that cancelled.
        extension_name: String,
        /// Reason supplied by the extension, if it gave one.
        reason: Option<String>,
    },
}
/// Result of routing a command to the extension registered for it
/// (returned by [`dispatch_command`]).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommandOutcome {
    /// The extension ran but produced nothing to act on (it answered with
    /// continue, or with an action that is not valid for commands).
    NoOp,
    /// The extension answered with a reply action carrying `text`.
    Reply { text: String },
    /// The extension answered with a send action.
    Send {
        /// Text carried by the send action.
        text: String,
        /// Attachments carried by the send action.
        attachments: Vec<crate::user_message::Attachment>,
    },
    /// The extension answered with a cancel action.
    Cancelled {
        /// Name of the extension that cancelled.
        extension_name: String,
        /// Reason supplied by the extension, if it gave one.
        reason: Option<String>,
    },
    /// No extension is registered for the requested command name.
    Unknown,
}
/// Runs one extension process for a single event and returns the action it answered with.
///
/// The event is serialized to a single JSON line and written to the child's stdin; the
/// first line the child prints on stdout is parsed as an [`Action`]. Everything the child
/// prints on stderr is forwarded to the log at `warn` level.
///
/// The whole stdin/stdout exchange is bounded by `extension.effective_timeout_ms`. After
/// the response line has been read the child gets a short grace period to exit.
///
/// This function is fail-open: serialization errors, spawn errors, I/O errors, timeouts,
/// a non-zero exit status, an empty response and an unparseable response are all logged
/// and reported as [`Action::Continue`].
pub(crate) async fn spawn_one(extension: &RegisteredExtension, event: &Event) -> Action {
    /// How long the child gets to exit once it has produced its response line.
    const EXIT_GRACE: Duration = Duration::from_millis(100);
    /// Maximum number of characters of an unparseable response echoed into the log.
    const PREVIEW_CHARS: usize = 200;

    let entry = &extension.entry;
    let name = entry.name.as_str();
    let timeout_ms = extension.effective_timeout_ms;

    let event_json = match serde_json::to_string(event) {
        Ok(s) => s,
        Err(err) => {
            warn!("[ext:{name}] failed to serialize event: {err}");
            return Action::Continue;
        }
    };

    let mut cmd = Command::new(&entry.command);
    cmd.args(&entry.args)
        .stdin(std::process::Stdio::piped())
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .kill_on_drop(true);
    // Configured variables only fill gaps: a variable that is already set in this
    // process's environment is left as inherited.
    for (key, value) in &entry.env {
        if std::env::var_os(key).is_none() {
            cmd.env(key, value);
        }
    }

    let mut child = match cmd.spawn() {
        Ok(c) => c,
        Err(err) => {
            warn!("[ext:{name}] spawn failed: {err}");
            return Action::Continue;
        }
    };

    // Start draining stderr right away. Doing this before the stdout read means
    // diagnostics from extensions that time out or fail are still logged, and a chatty
    // child cannot stall on a full stderr pipe while we wait for its response.
    if let Some(stderr) = child.stderr.take() {
        let name = entry.name.clone();
        tokio::spawn(async move {
            let mut reader = BufReader::new(stderr);
            let mut buf = String::new();
            while let Ok(n) = reader.read_line(&mut buf).await {
                if n == 0 {
                    break;
                }
                let line = buf.trim_end();
                warn!("[ext:{name}] stderr: {line}");
                buf.clear();
            }
        });
    }

    let stdin = child.stdin.take();
    let stdout = match child.stdout.take() {
        Some(out) => out,
        None => {
            warn!("[ext:{name}] no stdout pipe");
            return Action::Continue;
        }
    };

    // The write lives inside the timed future: an extension that never reads its stdin
    // would otherwise block `write_all` forever once the pipe buffer is full.
    let exchange = async move {
        if let Some(mut stdin) = stdin {
            let line = format!("{event_json}\n");
            if let Err(err) = stdin.write_all(line.as_bytes()).await {
                warn!("[ext:{name}] writing stdin failed: {err}");
            }
            // `stdin` is dropped at the end of this scope, closing the pipe so the
            // child sees end-of-input.
        }
        let mut reader = BufReader::new(stdout);
        let mut line = String::new();
        reader.read_line(&mut line).await.map(|_| line)
    };

    let deadline = Duration::from_millis(timeout_ms);
    let line = match tokio::time::timeout(deadline, exchange).await {
        Ok(Ok(line)) => line,
        Ok(Err(err)) => {
            warn!("[ext:{name}] reading stdout failed: {err}");
            let _ = child.kill().await;
            return Action::Continue;
        }
        Err(_) => {
            warn!("[ext:{name}] timed out after {timeout_ms}ms");
            let _ = child.kill().await;
            return Action::Continue;
        }
    };

    // A response is only honoured if the process also finishes cleanly and promptly.
    match tokio::time::timeout(EXIT_GRACE, child.wait()).await {
        Ok(Ok(status)) if !status.success() => {
            warn!("[ext:{name}] exited with {status}");
            return Action::Continue;
        }
        Ok(Ok(_)) => {}
        // The exit status could not be determined; the response line is still used.
        Ok(Err(err)) => warn!("[ext:{name}] checking exit status failed: {err}"),
        Err(_) => {
            warn!("[ext:{name}] did not exit promptly after stdout; killed");
            let _ = child.kill().await;
            return Action::Continue;
        }
    }

    // No output (or only whitespace) means "nothing to say".
    let trimmed = line.trim();
    if trimmed.is_empty() {
        return Action::Continue;
    }

    match serde_json::from_str::<Action>(trimmed) {
        Ok(action) => action,
        Err(err) => {
            let preview: String = trimmed.chars().take(PREVIEW_CHARS).collect();
            warn!("[ext:{name}] could not parse response (`{preview}`): {err}");
            Action::Continue
        }
    }
}
/// Dispatches the `session_before_switch` event to every subscribed extension.
///
/// Extensions are run one at a time in the order recorded in the registry's hook index.
/// The first extension that answers with a cancel action stops the chain and its name
/// and reason are returned in [`HookOutcome::Cancelled`]. Actions that make no sense for
/// this event (transform text, reply, send) are logged and treated as continue.
///
/// Returns [`HookOutcome::Continue`] when nothing is subscribed or nobody cancelled.
pub async fn dispatch_session_before_switch(
    registry: &ExtensionRegistry,
    reason: &str,
    session_id: Option<&str>,
) -> HookOutcome {
    // The registry is borrowed for the whole call, so the index list can be borrowed too.
    let subscribers = match registry.hook_index.get(&EventName::SessionBeforeSwitch) {
        Some(indices) => indices,
        None => return HookOutcome::Continue,
    };

    // The event is identical for every subscriber, so build it once.
    let event = Event::SessionBeforeSwitch {
        reason: reason.to_string(),
        session_id: session_id.map(|s| s.to_string()),
    };

    for idx in subscribers.iter().copied() {
        let extension = &registry.extensions[idx];
        match spawn_one(extension, &event).await {
            Action::Continue => continue,
            Action::Cancel { reason } => {
                return HookOutcome::Cancelled {
                    extension_name: extension.entry.name.clone(),
                    reason,
                };
            }
            Action::TransformText { .. } | Action::Reply { .. } | Action::Send { .. } => {
                tracing::warn!(
                    target: "extensions",
                    "[ext:{}] returned action not valid for session_before_switch; treating as continue",
                    extension.entry.name
                );
                continue;
            }
        }
    }

    HookOutcome::Continue
}
/// Runs the `before_user_message` hook chain over an outgoing user message.
///
/// Extensions are run one at a time in the order recorded in the registry's hook index.
/// Each one sees the text as left by the extensions before it: a transform-text action
/// replaces the text for the rest of the chain. The attachments are passed to every
/// extension unchanged. The first cancel action stops the chain. Reply and send actions
/// are not valid here; they are logged and treated as continue.
///
/// Returns [`BeforeUserMessageOutcome::Proceed`] with the final text and the original
/// attachments, or [`BeforeUserMessageOutcome::Cancelled`] naming the cancelling extension.
pub async fn dispatch_before_user_message(
    registry: &ExtensionRegistry,
    text: String,
    attachments: Vec<crate::user_message::Attachment>,
) -> BeforeUserMessageOutcome {
    // The registry is borrowed for the whole call, so the index list can be borrowed too.
    let subscribers = match registry.hook_index.get(&EventName::BeforeUserMessage) {
        Some(indices) => indices,
        None => return BeforeUserMessageOutcome::Proceed { text, attachments },
    };

    let mut current_text = text;
    for idx in subscribers.iter().copied() {
        let extension = &registry.extensions[idx];
        // Rebuilt per extension because the text may have been rewritten by the previous one.
        let event = Event::BeforeUserMessage {
            text: current_text.clone(),
            attachments: attachments.clone(),
        };
        match spawn_one(extension, &event).await {
            Action::Continue => continue,
            Action::Cancel { reason } => {
                return BeforeUserMessageOutcome::Cancelled {
                    extension_name: extension.entry.name.clone(),
                    reason,
                };
            }
            Action::TransformText { text } => {
                current_text = text;
            }
            Action::Reply { .. } | Action::Send { .. } => {
                tracing::warn!(
                    target: "extensions",
                    "[ext:{}] returned reply/send for before_user_message; treating as continue",
                    extension.entry.name
                );
                continue;
            }
        }
    }

    BeforeUserMessageOutcome::Proceed {
        text: current_text,
        attachments,
    }
}
/// Routes a command to the single extension registered for `command_name`.
///
/// Returns [`CommandOutcome::Unknown`] when the registry's command index has no entry
/// for the name. Otherwise the extension is run once with a command event carrying the
/// name and `args`, and its action is mapped onto a [`CommandOutcome`]: continue becomes
/// `NoOp`, cancel/reply/send map to the variants of the same name, and a transform-text
/// action (not valid for commands) is logged and treated as `NoOp`.
pub async fn dispatch_command(
    registry: &ExtensionRegistry,
    command_name: &str,
    args: &str,
) -> CommandOutcome {
    let idx = match registry.command_index.get(command_name) {
        Some(i) => *i,
        None => return CommandOutcome::Unknown,
    };
    let extension = &registry.extensions[idx];

    let event = Event::Command {
        name: command_name.to_string(),
        args: args.to_string(),
    };

    match spawn_one(extension, &event).await {
        Action::Continue => CommandOutcome::NoOp,
        Action::Cancel { reason } => CommandOutcome::Cancelled {
            extension_name: extension.entry.name.clone(),
            reason,
        },
        Action::Reply { text } => CommandOutcome::Reply { text },
        Action::Send { text, attachments } => CommandOutcome::Send { text, attachments },
        Action::TransformText { .. } => {
            tracing::warn!(
                target: "extensions",
                "[ext:{}] returned transform_text for command event; treating as no-op",
                extension.entry.name
            );
            CommandOutcome::NoOp
        }
    }
}