use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use parking_lot::Mutex;
use serde::Deserialize;
use serde_json::{Value, json};
use tokio::sync::oneshot;
use uuid::Uuid;
use oxi_sdk::{AgentTool, AgentToolResult, ToolContext, ToolError};
use crate::event_bus::{EventBus, KernelEvent};
/// A single outstanding `ask_user` request: the one-shot channel used to
/// deliver the user's answer back to the tool call awaiting it.
struct PendingEntry {
    /// Consumed by `PendingAskUser::resolve` to send the answer, or dropped by
    /// `PendingAskUser::cancel_all`, which makes the awaiting receiver error.
    sender: oneshot::Sender<String>,
}
/// Registry of `ask_user` questions that have been registered but not yet
/// answered or cancelled, keyed by the request id returned from `register`.
#[derive(Default)]
pub struct PendingAskUser {
    // Request id -> channel for delivering that request's answer.
    inner: Mutex<HashMap<String, PendingEntry>>,
}
impl PendingAskUser {
    /// Creates an empty registry with no outstanding questions.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a new pending question.
    ///
    /// Returns a freshly generated UUID request id together with the receiver
    /// on which the answer will arrive once [`Self::resolve`] is called.
    pub fn register(&self) -> (String, oneshot::Receiver<String>) {
        let (sender, receiver) = oneshot::channel();
        let request_id = Uuid::new_v4().to_string();
        self.inner
            .lock()
            .insert(request_id.clone(), PendingEntry { sender });
        (request_id, receiver)
    }

    /// Delivers `answer` to the request identified by `id` and forgets it.
    ///
    /// Returns `false` when no such request is pending (unknown id, already
    /// resolved, or cancelled). Returns `true` when the entry was found, even
    /// if the waiting side has since gone away.
    pub fn resolve(&self, id: &str, answer: String) -> bool {
        // The lock guard is a temporary of this statement, so it is released
        // before the answer is sent.
        let removed = self.inner.lock().remove(id);
        match removed {
            Some(PendingEntry { sender }) => {
                // A vanished receiver is not an error for the resolver.
                let _ = sender.send(answer);
                true
            }
            None => false,
        }
    }

    /// Cancels every outstanding request.
    ///
    /// Clearing the map drops each sender, so every awaiting receiver
    /// completes with an error instead of an answer.
    pub fn cancel_all(&self) {
        self.inner.lock().clear();
    }
}
/// Agent tool that publishes a question to the user on the event bus and
/// waits for the answer to be delivered through the pending-request registry.
pub struct AskUserTool {
    // Registry where each question is registered and later resolved; held in
    // an `Arc` so it can be shared with whatever code delivers the answers.
    pending: Arc<PendingAskUser>,
    // Used to publish `KernelEvent::AskUserRequest` for each question.
    event_bus: EventBus,
}
impl AskUserTool {
pub fn new(pending: Arc<PendingAskUser>, event_bus: EventBus) -> Self {
Self { pending, event_bus }
}
}
impl std::fmt::Debug for AskUserTool {
    /// Prints only the type name; the tool's fields are intentionally omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("AskUserTool")
    }
}
/// Arguments accepted by the `ask_user` tool, deserialized from the tool-call
/// parameters (see `parameters_schema`).
#[derive(Debug, Deserialize)]
struct AskUserArgs {
    /// The question shown to the user; `execute` rejects blank values.
    question: String,
    /// Optional choices for a structured picker; an omitted field becomes an
    /// empty list.
    #[serde(default)]
    options: Vec<String>,
}
#[async_trait]
impl AgentTool for AskUserTool {
fn name(&self) -> &str {
"ask_user"
}
fn label(&self) -> &str {
"Ask User"
}
fn description(&self) -> &'static str {
"Ask the user a clarifying question during task execution. \
Provide a `question` and optionally a list of `options` for a \
structured picker. Execution blocks until the user responds or \
the request is cancelled."
}
fn parameters_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"question": {
"type": "string",
"description": "The question to ask the user."
},
"options": {
"type": "array",
"items": { "type": "string" },
"description": "Optional list of choices for a structured picker. \
Omit for an open-ended text input."
}
},
"required": ["question"]
})
}
async fn execute(
&self,
_tool_call_id: &str,
params: Value,
_signal: Option<tokio::sync::oneshot::Receiver<()>>,
_ctx: &ToolContext,
) -> Result<AgentToolResult, ToolError> {
let args: AskUserArgs =
serde_json::from_value(params).map_err(|e| format!("Invalid arguments: {e}"))?;
if args.question.trim().is_empty() {
return Err("question must not be empty".to_string());
}
let (id, rx) = self.pending.register();
let event = KernelEvent::AskUserRequest {
id: id.clone(),
question: args.question.clone(),
options: args.options.clone(),
};
if let Err(e) = self.event_bus.publish(event) {
self.pending.resolve(&id, String::new());
return Err(format!("Failed to publish AskUserRequest event: {e}"));
}
tracing::info!(
request_id = %id,
options = args.options.len(),
"ask_user: question published, awaiting user response"
);
let answer = match rx.await {
Ok(answer) => answer,
Err(_) => {
tracing::warn!(request_id = %id, "ask_user: receiver dropped before response");
return Err("ask_user request was cancelled before the user responded".to_string());
}
};
Ok(AgentToolResult::success(answer))
}
}