use std::sync::Arc;
use roder_api::tools::{
ToolCall, ToolExecutionContext, ToolExecutor, ToolRegistry, ToolResult, ToolSpec,
};
use serde::Deserialize;
use serde_json::{Value, json};
use tokio::sync::Mutex;
use crate::files::{parse, require_nonempty, result};
/// Registers the workflow tools with `registry`.
///
/// Registration order is: `update_plan` (with a fresh, empty plan store), then
/// everything from `crate::goals::register`, then `request_user_input`. The
/// first registration error is returned immediately.
pub(crate) fn register(registry: &mut ToolRegistry) -> anyhow::Result<()> {
    let update_plan = UpdatePlanTool {
        state: Arc::new(Mutex::new(PlanState::default())),
    };
    registry.register(Arc::new(update_plan))?;

    crate::goals::register(registry)?;

    registry.register(Arc::new(RequestUserInputTool))
}
/// The most recent plan accepted by `update_plan`; each update replaces it wholesale.
#[derive(Default, Debug)]
struct PlanState {
    /// Optional free-form note supplied with the latest update.
    explanation: Option<String>,
    /// Plan steps, in the order the caller supplied them.
    items: Vec<PlanItem>,
}
/// A single plan step, as deserialized from `update_plan` arguments.
#[derive(Clone, Debug, Deserialize)]
struct PlanItem {
    /// Text describing the step.
    step: String,
    /// Progress state of the step.
    status: PlanStatus,
}
/// Progress state of a plan step.
///
/// Deserialized from the snake_case strings `"pending"`, `"in_progress"`, and
/// `"completed"`.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum PlanStatus {
    Pending,
    InProgress,
    Completed,
}
/// Executor for the `update_plan` tool.
#[derive(Debug)]
struct UpdatePlanTool {
    /// Plan store behind an async mutex; shared with any other holder of the `Arc`.
    state: Arc<Mutex<PlanState>>,
}
/// Stateless executor for the `request_user_input` tool.
#[derive(Debug)]
struct RequestUserInputTool;
#[async_trait::async_trait]
impl ToolExecutor for UpdatePlanTool {
    /// Describes the `update_plan` tool: a required `plan` array of
    /// `{step, status}` items plus an optional `explanation` string.
    fn spec(&self) -> ToolSpec {
        ToolSpec {
            name: "update_plan".to_string(),
            description: "Updates the task plan.".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "explanation": {
                        "type": "string",
                        "description": "Optional explanation for the plan update."
                    },
                    "plan": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "step": { "type": "string" },
                                "status": {
                                    "type": "string",
                                    "enum": ["pending", "in_progress", "completed"]
                                }
                            },
                            "required": ["step", "status"],
                            "additionalProperties": false
                        }
                    }
                },
                "required": ["plan"],
                "additionalProperties": false
            }),
        }
    }

    /// Validates the submitted plan and, if valid, replaces the stored plan with it.
    ///
    /// Validation:
    /// - More than one `in_progress` item yields `Ok` with an error `ToolResult`.
    /// - A step whose trimmed text fails `require_nonempty` propagates that
    ///   function's error via `?`.
    ///
    /// Step text is trimmed before it is stored. This keeps the structured `plan`
    /// payload consistent with the text summary from `format_plan`, which also trims.
    async fn execute(
        &self,
        _ctx: ToolExecutionContext,
        call: ToolCall,
    ) -> anyhow::Result<ToolResult> {
        let args = parse::<UpdatePlanArgs>(&call)?;
        let in_progress = args
            .plan
            .iter()
            .filter(|item| item.status == PlanStatus::InProgress)
            .count();
        if in_progress > 1 {
            return Ok(error_result(
                call,
                "update_plan accepts at most one in_progress item".to_string(),
            ));
        }

        let mut items = args.plan;
        for item in &mut items {
            let step = item.step.trim();
            require_nonempty(step, "step")?;
            // Normalize once here so every consumer of the stored plan sees the same text.
            item.step = step.to_string();
        }

        let mut state = self.state.lock().await;
        state.explanation = args.explanation;
        state.items = items;
        let text = format_plan(&state);
        Ok(result(
            call,
            text,
            json!({
                "explanation": state.explanation,
                "plan": state.items.iter().map(plan_item_json).collect::<Vec<_>>(),
            }),
            false,
        ))
    }
}
#[async_trait::async_trait]
impl ToolExecutor for RequestUserInputTool {
    /// Describes the `request_user_input` tool.
    ///
    /// The schema advertises the same limits that `execute` enforces: one to three
    /// questions, each with two or three options.
    fn spec(&self) -> ToolSpec {
        ToolSpec {
            name: "request_user_input".to_string(),
            description:
                "Request user input for one to three short questions and wait for the response."
                    .to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "questions": {
                        "type": "array",
                        "minItems": 1,
                        "maxItems": 3,
                        "items": {
                            "type": "object",
                            "properties": {
                                "header": { "type": "string" },
                                "id": { "type": "string" },
                                "question": { "type": "string" },
                                "options": {
                                    "type": "array",
                                    "minItems": 2,
                                    "maxItems": 3,
                                    "items": {
                                        "type": "object",
                                        "properties": {
                                            "label": { "type": "string" },
                                            "description": { "type": "string" }
                                        },
                                        "required": ["label", "description"],
                                        "additionalProperties": false
                                    }
                                }
                            },
                            "required": ["header", "id", "question", "options"],
                            "additionalProperties": false
                        }
                    }
                },
                "required": ["questions"],
                "additionalProperties": false
            }),
        }
    }

    /// Validates the questions and returns a pending `user_input_request` payload
    /// keyed by the call id.
    ///
    /// These conditions return `Ok` with an error `ToolResult`:
    /// - the question count is outside 1..=3;
    /// - any question has an option count outside 2..=3;
    /// - two questions share the same (trimmed) id.
    ///
    /// Blank header/id/question/label/description text instead propagates the
    /// error from `require_nonempty` via `?`.
    async fn execute(
        &self,
        _ctx: ToolExecutionContext,
        call: ToolCall,
    ) -> anyhow::Result<ToolResult> {
        let args = parse::<RequestUserInputArgs>(&call)?;
        if args.questions.is_empty() || args.questions.len() > 3 {
            return Ok(error_result(
                call,
                "request_user_input requires one to three questions".to_string(),
            ));
        }
        for question in &args.questions {
            require_nonempty(question.header.trim(), "header")?;
            require_nonempty(question.id.trim(), "id")?;
            require_nonempty(question.question.trim(), "question")?;
            if question.options.len() < 2 || question.options.len() > 3 {
                return Ok(error_result(
                    call,
                    "each request_user_input question requires two or three options".to_string(),
                ));
            }
            for option in &question.options {
                require_nonempty(option.label.trim(), "label")?;
                require_nonempty(option.description.trim(), "description")?;
            }
        }
        // Checked in a separate pass after per-question validation, so inputs that
        // were already rejected above keep their original error. Duplicate ids
        // would make the request ambiguous; presumably the client correlates
        // answers by `id` (NOTE(review): confirm against the consumer).
        let mut seen_ids = std::collections::HashSet::new();
        if !args
            .questions
            .iter()
            .all(|question| seen_ids.insert(question.id.trim()))
        {
            return Ok(error_result(
                call,
                "request_user_input question ids must be unique".to_string(),
            ));
        }
        let request = json!({
            "request_id": call.id,
            "questions": args.questions.iter().map(user_question_json).collect::<Vec<_>>(),
        });
        Ok(result(
            call,
            "waiting for user input".to_string(),
            json!({ "user_input_request": request }),
            false,
        ))
    }
}
/// Arguments accepted by `update_plan`.
#[derive(Deserialize)]
struct UpdatePlanArgs {
    /// Optional note accompanying the update.
    explanation: Option<String>,
    /// The full replacement plan; it overwrites the previously stored items.
    plan: Vec<PlanItem>,
}
/// Arguments accepted by `request_user_input`.
#[derive(Deserialize)]
struct RequestUserInputArgs {
    /// Questions to ask; `execute` requires one to three.
    questions: Vec<UserQuestion>,
}
/// One question in a `request_user_input` call.
///
/// `execute` requires `header`, `id`, and `question` to be non-blank.
#[derive(Deserialize)]
struct UserQuestion {
    /// Header text for the question.
    header: String,
    /// Identifier echoed back in the request payload.
    id: String,
    /// The question text itself.
    question: String,
    /// Answer choices; `execute` requires two or three.
    options: Vec<UserInputOption>,
}
/// One answer choice for a `UserQuestion`.
///
/// `execute` requires both fields to be non-blank.
#[derive(Deserialize)]
struct UserInputOption {
    /// Short label for the choice.
    label: String,
    /// Longer explanation of the choice.
    description: String,
}
/// Builds a failed `ToolResult` for a rejected request.
///
/// `message` is used both as the result text and as
/// `data.error.message`; `data.error.kind` is always `"invalid_request"`.
fn error_result(call: ToolCall, message: String) -> ToolResult {
    // `json!` only borrows `message`, so it can be moved into `result` afterwards
    // without an extra clone.
    let payload = json!({
        "error": {
            "kind": "invalid_request",
            "message": message,
        }
    });
    result(call, message, payload, true)
}
/// Renders the plan as plain text.
///
/// The output is:
/// - the trimmed explanation on the first line, only if it is not blank;
/// - then one `- <status>: <step>` line per item, with the step trimmed.
///
/// Trailing whitespace on the final string is removed.
fn format_plan(state: &PlanState) -> String {
    let header = state
        .explanation
        .as_deref()
        .map(str::trim)
        .filter(|explanation| !explanation.is_empty())
        .map(String::from);
    let item_lines = state
        .items
        .iter()
        .map(|item| format!("- {}: {}", status_label(&item.status), item.step.trim()));
    let lines: Vec<String> = header.into_iter().chain(item_lines).collect();
    lines.join("\n").trim_end().to_string()
}
fn plan_item_json(item: &PlanItem) -> Value {
json!({
"step": item.step,
"status": status_label(&item.status),
})
}
/// Returns the snake_case wire label for `status`.
///
/// The labels match the strings accepted by deserialization and listed in the
/// `update_plan` schema.
fn status_label(status: &PlanStatus) -> &'static str {
    match *status {
        PlanStatus::Completed => "completed",
        PlanStatus::InProgress => "in_progress",
        PlanStatus::Pending => "pending",
    }
}
fn user_question_json(question: &UserQuestion) -> Value {
json!({
"header": question.header,
"id": question.id,
"question": question.question,
"options": question.options.iter().map(user_option_json).collect::<Vec<_>>(),
})
}
fn user_option_json(option: &UserInputOption) -> Value {
json!({
"label": option.label,
"description": option.description,
})
}
#[cfg(test)]
mod tests {
    use roder_api::events::{ThreadId, TurnId};
    use roder_api::policy_mode::PolicyMode;

    use super::*;

    #[tokio::test]
    async fn update_plan_rejects_multiple_in_progress_items() {
        let tool = UpdatePlanTool {
            state: Arc::new(Mutex::new(PlanState::default())),
        };
        let result = tool
            .execute(
                context(),
                call(
                    "update_plan",
                    json!({
                        "plan": [
                            { "step": "one", "status": "in_progress" },
                            { "step": "two", "status": "in_progress" }
                        ]
                    }),
                ),
            )
            .await
            .unwrap();
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn update_plan_stores_plan_and_reports_it() {
        let tool = UpdatePlanTool {
            state: Arc::new(Mutex::new(PlanState::default())),
        };
        let result = tool
            .execute(
                context(),
                call(
                    "update_plan",
                    json!({
                        "explanation": "Starting work",
                        "plan": [
                            { "step": "one", "status": "completed" },
                            { "step": "two", "status": "in_progress" }
                        ]
                    }),
                ),
            )
            .await
            .unwrap();
        assert!(!result.is_error);
        assert_eq!(result.data["explanation"], "Starting work");
        assert_eq!(result.data["plan"][0]["step"], "one");
        assert_eq!(result.data["plan"][1]["status"], "in_progress");

        // The accepted plan must be persisted in the shared state.
        let state = tool.state.lock().await;
        assert_eq!(state.items.len(), 2);
        assert_eq!(state.items[0].status, PlanStatus::Completed);
        assert_eq!(
            format_plan(&state),
            "Starting work\n- completed: one\n- in_progress: two"
        );
    }

    #[tokio::test]
    async fn request_user_input_returns_pending_request_payload() {
        let tool = RequestUserInputTool;
        let result = tool
            .execute(
                context(),
                call(
                    "request_user_input",
                    json!({
                        "questions": [{
                            "header": "Mode",
                            "id": "mode",
                            "question": "Which mode?",
                            "options": [
                                { "label": "Safe", "description": "Keep restrictions." },
                                { "label": "Fast", "description": "Allow more automation." }
                            ]
                        }]
                    }),
                ),
            )
            .await
            .unwrap();
        assert!(!result.is_error);
        assert_eq!(
            result.data["user_input_request"]["questions"][0]["id"],
            "mode"
        );
    }

    #[tokio::test]
    async fn request_user_input_rejects_empty_question_list() {
        let result = RequestUserInputTool
            .execute(
                context(),
                call("request_user_input", json!({ "questions": [] })),
            )
            .await
            .unwrap();
        assert!(result.is_error);
    }

    #[tokio::test]
    async fn request_user_input_rejects_single_option() {
        let result = RequestUserInputTool
            .execute(
                context(),
                call(
                    "request_user_input",
                    json!({
                        "questions": [{
                            "header": "Mode",
                            "id": "mode",
                            "question": "Which mode?",
                            "options": [
                                { "label": "Safe", "description": "Keep restrictions." }
                            ]
                        }]
                    }),
                ),
            )
            .await
            .unwrap();
        assert!(result.is_error);
    }

    /// Builds a `ToolCall` for `name` with the given JSON arguments.
    fn call(name: &str, arguments: Value) -> ToolCall {
        ToolCall {
            id: format!("call-{name}"),
            name: name.to_string(),
            arguments,
            raw_arguments: "{}".to_string(),
            thread_id: "thread-workflow".to_string(),
            turn_id: "turn-workflow".to_string(),
        }
    }

    /// Builds an execution context with the default policy mode.
    fn context() -> ToolExecutionContext {
        ToolExecutionContext::new(
            ThreadId::from("thread-workflow"),
            TurnId::from("turn-workflow"),
            PolicyMode::Default,
        )
    }
}