oxios_kernel/tools/
ask_user_tool.rs1use std::collections::HashMap;
18use std::sync::Arc;
19
20use async_trait::async_trait;
21use parking_lot::Mutex;
22use serde::Deserialize;
23use serde_json::{Value, json};
24use tokio::sync::oneshot;
25use uuid::Uuid;
26
27use oxi_sdk::{AgentTool, AgentToolResult, ToolContext, ToolError};
28
29use crate::event_bus::{EventBus, KernelEvent};
30struct PendingEntry {
33 sender: oneshot::Sender<String>,
34}
35
36#[derive(Default)]
41pub struct PendingAskUser {
42 inner: Mutex<HashMap<String, PendingEntry>>,
43}
44
45impl PendingAskUser {
46 pub fn new() -> Self {
48 Self::default()
49 }
50
51 pub fn register(&self) -> (String, oneshot::Receiver<String>) {
54 let id = Uuid::new_v4().to_string();
55 let (tx, rx) = oneshot::channel();
56 self.inner
57 .lock()
58 .insert(id.clone(), PendingEntry { sender: tx });
59 (id, rx)
60 }
61
62 pub fn resolve(&self, id: &str, answer: String) -> bool {
65 let Some(entry) = self.inner.lock().remove(id) else {
66 return false;
67 };
68 let _ = entry.sender.send(answer);
71 true
72 }
73
74 pub fn cancel_all(&self) {
78 let mut guard = self.inner.lock();
79 for (_, entry) in guard.drain() {
80 drop(entry.sender);
83 }
84 }
85}
86
87pub struct AskUserTool {
95 pending: Arc<PendingAskUser>,
96 event_bus: EventBus,
97}
98
99impl AskUserTool {
100 pub fn new(pending: Arc<PendingAskUser>, event_bus: EventBus) -> Self {
102 Self { pending, event_bus }
103 }
104}
105
106impl std::fmt::Debug for AskUserTool {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 f.debug_struct("AskUserTool").finish()
109 }
110}
111
112#[derive(Debug, Deserialize)]
113struct AskUserArgs {
114 question: String,
115 #[serde(default)]
116 options: Vec<String>,
117}
118
119#[async_trait]
120impl AgentTool for AskUserTool {
121 fn name(&self) -> &str {
122 "ask_user"
123 }
124
125 fn label(&self) -> &str {
126 "Ask User"
127 }
128
129 fn description(&self) -> &'static str {
130 "Ask the user a clarifying question during task execution. \
131 Provide a `question` and optionally a list of `options` for a \
132 structured picker. Execution blocks until the user responds or \
133 the request is cancelled."
134 }
135
136 fn parameters_schema(&self) -> Value {
137 json!({
138 "type": "object",
139 "properties": {
140 "question": {
141 "type": "string",
142 "description": "The question to ask the user."
143 },
144 "options": {
145 "type": "array",
146 "items": { "type": "string" },
147 "description": "Optional list of choices for a structured picker. \
148 Omit for an open-ended text input."
149 }
150 },
151 "required": ["question"]
152 })
153 }
154
155 async fn execute(
156 &self,
157 _tool_call_id: &str,
158 params: Value,
159 _signal: Option<tokio::sync::oneshot::Receiver<()>>,
160 _ctx: &ToolContext,
161 ) -> Result<AgentToolResult, ToolError> {
162 let args: AskUserArgs =
163 serde_json::from_value(params).map_err(|e| format!("Invalid arguments: {e}"))?;
164
165 if args.question.trim().is_empty() {
166 return Err("question must not be empty".to_string());
167 }
168
169 let (id, rx) = self.pending.register();
172
173 let event = KernelEvent::AskUserRequest {
174 id: id.clone(),
175 question: args.question.clone(),
176 options: args.options.clone(),
177 };
178 if let Err(e) = self.event_bus.publish(event) {
179 self.pending.resolve(&id, String::new());
181 return Err(format!("Failed to publish AskUserRequest event: {e}"));
182 }
183
184 tracing::info!(
185 request_id = %id,
186 options = args.options.len(),
187 "ask_user: question published, awaiting user response"
188 );
189
190 let answer = match rx.await {
192 Ok(answer) => answer,
193 Err(_) => {
194 tracing::warn!(request_id = %id, "ask_user: receiver dropped before response");
195 return Err("ask_user request was cancelled before the user responded".to_string());
196 }
197 };
198
199 Ok(AgentToolResult::success(answer))
200 }
201}