use crate::daemon::PendingAsks;
use runtime::Hook;
use serde::Deserialize;
use std::time::Duration;
use tokio::sync::oneshot;
use wcore::{ToolDispatch, ToolFuture, agent::AsTool};
/// Upper bound on how long an `ask_user` call waits for a reply: five minutes.
const ASK_USER_TIMEOUT: Duration = Duration::from_secs(5 * 60);
// One selectable answer attached to a `Question`.
//
// Plain `//` comments are used here on purpose: `///` doc comments would be
// read by the `schemars::JsonSchema` derive and emitted as schema descriptions.
#[derive(Deserialize, schemars::JsonSchema)]
pub struct QuestionOption {
// Text of the option. Required when deserializing (no serde default).
// NOTE(review): presumably the short text shown to the user — confirm against the client.
pub label: String,
// Longer explanation of the option. Also required when deserializing.
pub description: String,
}
// A single question carried by an `AskUser` tool call.
//
// Plain `//` comments are used here on purpose: `///` doc comments would be
// read by the `schemars::JsonSchema` derive and emitted as schema descriptions.
#[derive(Deserialize, schemars::JsonSchema)]
pub struct Question {
// The question text itself.
pub question: String,
// Short heading for the question; the timeout error in `dispatch` lists these
// headers to identify which questions went unanswered.
pub header: String,
// The answers offered for this question.
pub options: Vec<QuestionOption>,
// Deserializes to `false` when the field is absent from the input.
// NOTE(review): presumably allows picking more than one option — confirm against the client.
#[serde(default)]
pub multi_select: bool,
}
// Arguments of the `ask_user` tool call: `dispatch` deserializes the call's JSON
// arguments into this type, and `schema` derives the tool definition from it via
// `AskUser::as_tool()`.
//
// Plain `//` comments are used here on purpose: `///` doc comments would be
// read by the `schemars::JsonSchema` derive and emitted as schema descriptions.
#[derive(Deserialize, schemars::JsonSchema)]
pub struct AskUser {
// The questions to put to the user in this call.
pub questions: Vec<Question>,
}
/// Hook that provides the `ask_user` tool.
pub struct AskUserHook {
/// Registry of in-flight asks: `dispatch` stores one oneshot sender per
/// conversation id here and waits on the matching receiver for the reply.
pending_asks: PendingAsks,
}
impl AskUserHook {
pub fn new(pending_asks: PendingAsks) -> Self {
Self { pending_asks }
}
pub fn pending_asks(&self) -> &PendingAsks {
&self.pending_asks
}
}
impl Hook for AskUserHook {
    /// Returns the single tool definition this hook provides, built by
    /// `AskUser::as_tool()`.
    fn schema(&self) -> Vec<wcore::model::Tool> {
        vec![AskUser::as_tool()]
    }

    /// Handles an `ask_user` tool call; returns `None` for any other tool name.
    ///
    /// The returned future:
    /// 1. parses `call.args` as [`AskUser`],
    /// 2. requires `call.conversation_id` to be present,
    /// 3. stores a oneshot sender in the pending-asks registry under that id,
    /// 4. waits up to `ASK_USER_TIMEOUT` for a reply on the matching receiver.
    ///
    /// # Errors
    ///
    /// The future resolves to an error string when the arguments do not
    /// deserialize, when there is no conversation id, when the reply channel is
    /// closed without a reply, or when the timeout elapses.
    fn dispatch<'a>(&'a self, name: &'a str, call: ToolDispatch) -> Option<ToolFuture<'a>> {
        if name != "ask_user" {
            return None;
        }
        Some(Box::pin(async move {
            let input: AskUser = serde_json::from_str(&call.args)
                .map_err(|e| format!("invalid arguments: {e}"))?;
            let conversation_id = call
                .conversation_id
                .ok_or("ask_user is only available in streaming mode")?;

            let (tx, rx) = oneshot::channel();
            self.pending_asks.lock().await.insert(conversation_id, tx);

            match tokio::time::timeout(ASK_USER_TIMEOUT, rx).await {
                Ok(Ok(reply)) => Ok(reply),
                // The receiver only fails once our sender has been dropped, so our
                // entry is already gone from the registry: it was either removed or
                // overwritten by a newer ask for the same conversation. Removing by
                // key here could only delete that newer ask's sender and cancel it,
                // so the registry is deliberately left untouched.
                Ok(Err(_)) => Err("ask_user cancelled: reply channel closed".to_owned()),
                Err(_) => {
                    // Nobody replied in time: our sender is still registered, so
                    // remove it to keep the registry from accumulating dead entries.
                    self.pending_asks.lock().await.remove(&conversation_id);
                    let headers: Vec<&str> =
                        input.questions.iter().map(|q| q.header.as_str()).collect();
                    Err(format!(
                        "ask_user timed out after {}s: no reply received for: {}",
                        ASK_USER_TIMEOUT.as_secs(),
                        headers.join("; "),
                    ))
                }
            }
        }))
    }
}