roder-tools 0.1.1

Agentic software development tools and SDKs for Roder.
Documentation
use std::sync::Arc;

use roder_api::tools::{
    ToolCall, ToolExecutionContext, ToolExecutor, ToolRegistry, ToolResult, ToolSpec,
};
use serde::Deserialize;
use serde_json::{Value, json};
use tokio::sync::Mutex;

use crate::files::{parse, require_nonempty, result};

/// Registers the workflow tools on `registry`.
///
/// Order of registration: `update_plan` (with a fresh, empty plan state), then every tool
/// from `crate::goals::register`, then `request_user_input`. The first registration error
/// aborts and is returned; the result of the final registration is returned as-is.
pub(crate) fn register(registry: &mut ToolRegistry) -> anyhow::Result<()> {
    let update_plan = UpdatePlanTool {
        state: Arc::new(Mutex::new(PlanState::default())),
    };
    registry.register(Arc::new(update_plan))?;

    crate::goals::register(registry)?;

    registry.register(Arc::new(RequestUserInputTool))
}

/// Plan most recently submitted through `update_plan`; starts out empty.
#[derive(Default, Debug)]
struct PlanState {
    /// Optional explanation accompanying the latest plan update.
    explanation: Option<String>,
    /// Plan steps from the latest update, in submission order.
    items: Vec<PlanItem>,
}

/// A single plan step paired with its current status.
#[derive(Clone, Debug, Deserialize)]
struct PlanItem {
    /// Text describing the step.
    step: String,
    /// Status of the step (`pending`, `in_progress`, or `completed` on the wire).
    status: PlanStatus,
}

/// Status of a plan step, deserialized from its snake_case name.
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[serde(rename_all = "snake_case")]
enum PlanStatus {
    /// Wire name `"pending"`.
    Pending,
    /// Wire name `"in_progress"`; `update_plan` rejects plans with more than one.
    InProgress,
    /// Wire name `"completed"`.
    Completed,
}

/// The `update_plan` tool, which keeps the latest submitted plan in shared state.
#[derive(Debug)]
struct UpdatePlanTool {
    /// Plan storage, shared via `Arc` and guarded by an async (`tokio`) mutex.
    state: Arc<Mutex<PlanState>>,
}

/// The stateless `request_user_input` tool.
///
/// It validates the questions and returns a pending request payload; it does not collect
/// answers itself.
#[derive(Debug)]
struct RequestUserInputTool;

#[async_trait::async_trait]
impl ToolExecutor for UpdatePlanTool {
    fn spec(&self) -> ToolSpec {
        ToolSpec {
            name: "update_plan".to_string(),
            description: "Updates the task plan.".to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "explanation": {
                        "type": "string",
                        "description": "Optional explanation for the plan update."
                    },
                    "plan": {
                        "type": "array",
                        "items": {
                            "type": "object",
                            "properties": {
                                "step": { "type": "string" },
                                "status": {
                                    "type": "string",
                                    "enum": ["pending", "in_progress", "completed"]
                                }
                            },
                            "required": ["step", "status"],
                            "additionalProperties": false
                        }
                    }
                },
                "required": ["plan"],
                "additionalProperties": false
            }),
        }
    }

    async fn execute(
        &self,
        _ctx: ToolExecutionContext,
        call: ToolCall,
    ) -> anyhow::Result<ToolResult> {
        let args = parse::<UpdatePlanArgs>(&call)?;
        let in_progress = args
            .plan
            .iter()
            .filter(|item| item.status == PlanStatus::InProgress)
            .count();
        if in_progress > 1 {
            return Ok(error_result(
                call,
                "update_plan accepts at most one in_progress item".to_string(),
            ));
        }
        for item in &args.plan {
            require_nonempty(item.step.trim(), "step")?;
        }

        let mut state = self.state.lock().await;
        state.explanation = args.explanation;
        state.items = args.plan;
        let text = format_plan(&state);
        Ok(result(
            call,
            text,
            json!({
                "explanation": state.explanation,
                "plan": state.items.iter().map(plan_item_json).collect::<Vec<_>>(),
            }),
            false,
        ))
    }
}

#[async_trait::async_trait]
impl ToolExecutor for RequestUserInputTool {
    /// Describes `request_user_input`: one to three questions, each with options.
    ///
    /// Every question has a header, an id, the question text, and two or three labelled
    /// options.
    fn spec(&self) -> ToolSpec {
        ToolSpec {
            name: "request_user_input".to_string(),
            description:
                "Request user input for one to three short questions and wait for the response."
                    .to_string(),
            parameters: json!({
                "type": "object",
                "properties": {
                    "questions": {
                        "type": "array",
                        "minItems": 1,
                        "maxItems": 3,
                        "items": {
                            "type": "object",
                            "properties": {
                                "header": { "type": "string" },
                                "id": { "type": "string" },
                                "question": { "type": "string" },
                                "options": {
                                    "type": "array",
                                    // Mirrors the two-to-three option check in `execute`, so
                                    // callers see the constraint in the schema.
                                    "minItems": 2,
                                    "maxItems": 3,
                                    "items": {
                                        "type": "object",
                                        "properties": {
                                            "label": { "type": "string" },
                                            "description": { "type": "string" }
                                        },
                                        "required": ["label", "description"],
                                        "additionalProperties": false
                                    }
                                }
                            },
                            "required": ["header", "id", "question", "options"],
                            "additionalProperties": false
                        }
                    }
                },
                "required": ["questions"],
                "additionalProperties": false
            }),
        }
    }

    /// Validates the questions and returns a pending `user_input_request` payload.
    ///
    /// The payload is keyed by the call id.
    ///
    /// Returns an `is_error` result when:
    /// - there are not one to three questions,
    /// - a question does not have two or three options, or
    /// - two questions share the same `id`.
    ///
    /// Malformed arguments and blank text fields are reported through the `Err` path of
    /// `parse` and `require_nonempty`.
    async fn execute(
        &self,
        _ctx: ToolExecutionContext,
        call: ToolCall,
    ) -> anyhow::Result<ToolResult> {
        let args = parse::<RequestUserInputArgs>(&call)?;
        if args.questions.is_empty() || args.questions.len() > 3 {
            return Ok(error_result(
                call,
                "request_user_input requires one to three questions".to_string(),
            ));
        }
        // There are at most three questions, so a linear scan is enough for the
        // uniqueness check.
        let mut seen_ids: Vec<&str> = Vec::with_capacity(args.questions.len());
        for question in &args.questions {
            require_nonempty(question.header.trim(), "header")?;
            require_nonempty(question.id.trim(), "id")?;
            require_nonempty(question.question.trim(), "question")?;
            if !(2..=3).contains(&question.options.len()) {
                return Ok(error_result(
                    call,
                    "each request_user_input question requires two or three options".to_string(),
                ));
            }
            for option in &question.options {
                require_nonempty(option.label.trim(), "label")?;
                require_nonempty(option.description.trim(), "description")?;
            }
            // NOTE(review): the id is the only per-question key in the emitted request, so
            // answers are presumably matched back by id. Duplicates would be ambiguous.
            if seen_ids.contains(&question.id.as_str()) {
                return Ok(error_result(
                    call,
                    format!(
                        "request_user_input question ids must be unique; `{}` is repeated",
                        question.id
                    ),
                ));
            }
            seen_ids.push(question.id.as_str());
        }
        let request = json!({
            "request_id": call.id,
            "questions": args.questions.iter().map(user_question_json).collect::<Vec<_>>(),
        });
        Ok(result(
            call,
            "waiting for user input".to_string(),
            json!({ "user_input_request": request }),
            false,
        ))
    }
}

/// Arguments accepted by `update_plan`.
#[derive(Deserialize)]
struct UpdatePlanArgs {
    /// Optional note for this update; an absent field deserializes to `None`.
    explanation: Option<String>,
    /// Complete list of plan steps that replaces the stored plan.
    plan: Vec<PlanItem>,
}

/// Arguments accepted by `request_user_input`.
#[derive(Deserialize)]
struct RequestUserInputArgs {
    /// Questions to ask; the tool accepts one to three.
    questions: Vec<UserQuestion>,
}

/// One question in a `request_user_input` call.
#[derive(Deserialize)]
struct UserQuestion {
    /// Short heading shown with the question.
    header: String,
    /// Caller-chosen identifier, echoed back in the request payload.
    id: String,
    /// The question text itself.
    question: String,
    /// Answer choices offered for this question.
    options: Vec<UserInputOption>,
}

/// One answer choice for a `UserQuestion`.
#[derive(Deserialize)]
struct UserInputOption {
    /// Short label for the choice.
    label: String,
    /// Longer explanation of what picking this choice means.
    description: String,
}

/// Builds an error tool result for an invalid request.
///
/// `message` becomes the result text, and the data is
/// `{"error": {"kind": "invalid_request", "message": message}}`. The result is flagged as an
/// error via the final `true` argument to `result`.
fn error_result(call: ToolCall, message: String) -> ToolResult {
    // Build the structured payload first so `message` can then be moved in as the text,
    // without cloning it.
    let payload = json!({
        "error": {
            "kind": "invalid_request",
            "message": &message,
        }
    });
    result(call, message, payload, true)
}

/// Renders a plan as plain text.
///
/// The output is the trimmed explanation (only when it is non-blank) on its own line,
/// followed by one `- <status>: <step>` line per item with the step trimmed. Trailing
/// whitespace is stripped from the result.
fn format_plan(state: &PlanState) -> String {
    let explanation = state
        .explanation
        .as_deref()
        .map(str::trim)
        .filter(|explanation| !explanation.is_empty());

    let mut lines: Vec<String> = Vec::with_capacity(state.items.len() + 1);
    lines.extend(explanation.map(str::to_string));
    lines.extend(
        state
            .items
            .iter()
            .map(|item| format!("- {}: {}", status_label(&item.status), item.step.trim())),
    );

    lines.join("\n").trim_end().to_string()
}

fn plan_item_json(item: &PlanItem) -> Value {
    json!({
        "step": item.step,
        "status": status_label(&item.status),
    })
}

fn status_label(status: &PlanStatus) -> &'static str {
    match status {
        PlanStatus::Pending => "pending",
        PlanStatus::InProgress => "in_progress",
        PlanStatus::Completed => "completed",
    }
}

fn user_question_json(question: &UserQuestion) -> Value {
    json!({
        "header": question.header,
        "id": question.id,
        "question": question.question,
        "options": question.options.iter().map(user_option_json).collect::<Vec<_>>(),
    })
}

fn user_option_json(option: &UserInputOption) -> Value {
    json!({
        "label": option.label,
        "description": option.description,
    })
}

#[cfg(test)]
mod tests {
    use roder_api::events::{ThreadId, TurnId};
    use roder_api::policy_mode::PolicyMode;

    use super::*;

    /// Thread identifier shared by every fixture call and context.
    const THREAD: &str = "thread-workflow";
    /// Turn identifier shared by every fixture call and context.
    const TURN: &str = "turn-workflow";

    #[tokio::test]
    async fn update_plan_rejects_multiple_in_progress_items() {
        let state = Arc::new(Mutex::new(PlanState::default()));
        let tool = UpdatePlanTool { state };
        let arguments = json!({
            "plan": [
                { "step": "one", "status": "in_progress" },
                { "step": "two", "status": "in_progress" }
            ]
        });

        let outcome = tool
            .execute(execution_context(), tool_call("update_plan", arguments))
            .await
            .expect("update_plan should produce a tool result");

        assert!(outcome.is_error);
    }

    #[tokio::test]
    async fn request_user_input_returns_pending_request_payload() {
        let arguments = json!({
            "questions": [{
                "header": "Mode",
                "id": "mode",
                "question": "Which mode?",
                "options": [
                    { "label": "Safe", "description": "Keep restrictions." },
                    { "label": "Fast", "description": "Allow more automation." }
                ]
            }]
        });

        let outcome = RequestUserInputTool
            .execute(execution_context(), tool_call("request_user_input", arguments))
            .await
            .expect("request_user_input should produce a tool result");

        assert!(!outcome.is_error);
        let first_question = &outcome.data["user_input_request"]["questions"][0];
        assert_eq!(first_question["id"], "mode");
    }

    /// Builds a tool call named `name` with the given parsed arguments.
    fn tool_call(name: &str, arguments: Value) -> ToolCall {
        ToolCall {
            id: format!("call-{name}"),
            name: name.to_string(),
            arguments,
            raw_arguments: "{}".to_string(),
            thread_id: THREAD.to_string(),
            turn_id: TURN.to_string(),
        }
    }

    /// Builds an execution context for the fixture thread/turn under the default policy.
    fn execution_context() -> ToolExecutionContext {
        ToolExecutionContext::new(
            ThreadId::from(THREAD),
            TurnId::from(TURN),
            PolicyMode::Default,
        )
    }
}