use crate::PrComments;
use crate::logging::ToolLogCtx;
use crate::models::CommentSourceType;
use crate::models::PrSummaryList;
use crate::models::ReviewComment;
use crate::models::ReviewCommentList;
use agentic_tools_core::Tool;
use agentic_tools_core::ToolContext;
use agentic_tools_core::ToolError;
use agentic_tools_core::ToolRegistry;
use futures::future::BoxFuture;
use schemars::JsonSchema;
use serde::Deserialize;
use std::sync::Arc;
// Input parameters for the `gh_get_comments` tool.
//
// Plain `//` comments are used on purpose: `///` doc comments would be picked up by
// `#[derive(JsonSchema)]` as schema descriptions and change the tool schema seen by callers.
// Every field is optional (`#[serde(default)]`) and is forwarded unchanged to
// `PrComments::get_comments`.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct GetCommentsInput {
    // Pull request number. How an omitted value is resolved is decided inside
    // `PrComments::get_comments` (not visible here — TODO confirm).
    #[serde(default)]
    pub pr_number: Option<u64>,
    // Optional filter on the kind of comment source; semantics live in `CommentSourceType`.
    #[serde(default)]
    pub comment_source_type: Option<CommentSourceType>,
    // Presumably whether resolved review threads are included — verify against
    // `PrComments::get_comments`.
    #[serde(default)]
    pub include_resolved: Option<bool>,
}
/// Tool wrapper that exposes `PrComments::get_comments` under the name `gh_get_comments`.
#[derive(Clone)]
pub struct GetCommentsTool {
    /// Shared client handle; each call clones the `Arc` into its own future.
    pr_comments: Arc<PrComments>,
}

impl GetCommentsTool {
    /// Creates the tool around a shared [`PrComments`] handle.
    pub fn new(shared: Arc<PrComments>) -> Self {
        Self {
            pr_comments: shared,
        }
    }
}
/// `gh_get_comments`: paginated retrieval of PR review comments.
impl Tool for GetCommentsTool {
    type Input = GetCommentsInput;
    type Output = ReviewCommentList;
    const NAME: &'static str = "gh_get_comments";
    const DESCRIPTION: &'static str = "Get PR review comments with thread-level implicit pagination. Repeated calls with the same params return the next page; tool output tells you whether to call again or stop, and another identical call after completion restarts from page 1.";

    /// Fetches a page of review comments via `PrComments::get_comments` and records a
    /// structured log entry for the call.
    ///
    /// The request parameters are logged as given. On success the log also records the
    /// number of comments returned plus `shown_threads`, `total_threads` and `has_more`
    /// from the result.
    ///
    /// # Errors
    ///
    /// Any failure from `get_comments` is logged (with its full context chain) and
    /// converted into a [`ToolError`] by `map_anyhow_to_tool_error`.
    fn call(
        &self,
        input: Self::Input,
        _ctx: &ToolContext,
    ) -> BoxFuture<'static, Result<Self::Output, ToolError>> {
        // Clone the `Arc` so the returned future is `'static` and does not borrow `self`.
        let pr_comments = Arc::clone(&self.pr_comments);
        Box::pin(async move {
            let log = ToolLogCtx::start(Self::NAME);
            let request = serde_json::json!({
                "pr_number": input.pr_number,
                "comment_source_type": input.comment_source_type,
                "include_resolved": input.include_resolved,
            });
            match pr_comments
                .get_comments(
                    input.pr_number,
                    input.comment_source_type,
                    input.include_resolved,
                )
                .await
            {
                Ok(out) => {
                    log.finish(
                        request,
                        None,
                        true,
                        None,
                        Some(serde_json::json!({
                            "comments": out.comments.len(),
                            "shown_threads": out.shown_threads,
                            "total_threads": out.total_threads,
                            "has_more": out.has_more,
                        })),
                        None,
                        None,
                    );
                    Ok(out)
                }
                Err(e) => {
                    // `{:#}` renders the whole anyhow chain ("outer: ...: root cause");
                    // `to_string()` would log only the outermost context and lose the cause.
                    let msg = format!("{e:#}");
                    log.finish(request, None, false, Some(msg), None, None, None);
                    Err(map_anyhow_to_tool_error(&e))
                }
            }
        })
    }
}
// Input parameters for the `gh_get_prs` tool.
//
// Plain `//` comments are used on purpose: `///` doc comments would become JSON-schema
// descriptions via `#[derive(JsonSchema)]`.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct ListPrsInput {
    // Optional PR state filter, forwarded unchanged to `PrComments::list_prs`. The accepted
    // values and the default when omitted are defined there (not visible here — TODO confirm).
    #[serde(default)]
    pub state: Option<String>,
}
/// Tool wrapper that exposes `PrComments::list_prs` under the name `gh_get_prs`.
#[derive(Clone)]
pub struct ListPrsTool {
    /// Shared client handle; each call clones the `Arc` into its own future.
    pr_comments: Arc<PrComments>,
}

impl ListPrsTool {
    /// Creates the tool around a shared [`PrComments`] handle.
    pub fn new(shared: Arc<PrComments>) -> Self {
        Self {
            pr_comments: shared,
        }
    }
}
/// `gh_get_prs`: paginated listing of the repository's pull requests.
impl Tool for ListPrsTool {
    type Input = ListPrsInput;
    type Output = PrSummaryList;
    const NAME: &'static str = "gh_get_prs";
    const DESCRIPTION: &'static str = "List pull requests in the repository with implicit pagination. Repeated calls with the same params return the next page; tool output tells you whether to call again or stop, and another identical call after completion restarts from page 1.";

    /// Lists pull requests via `PrComments::list_prs` and records a structured log entry.
    ///
    /// The requested `state` is logged as given. On success the log also records how many
    /// PRs were returned.
    ///
    /// # Errors
    ///
    /// Any failure from `list_prs` is logged (with its full context chain) and converted
    /// into a [`ToolError`] by `map_anyhow_to_tool_error`.
    fn call(
        &self,
        input: Self::Input,
        _ctx: &ToolContext,
    ) -> BoxFuture<'static, Result<Self::Output, ToolError>> {
        // Clone the `Arc` so the returned future is `'static` and does not borrow `self`.
        let pr_comments = Arc::clone(&self.pr_comments);
        Box::pin(async move {
            let log = ToolLogCtx::start(Self::NAME);
            let request = serde_json::json!({
                "state": input.state,
            });
            match pr_comments.list_prs(input.state).await {
                Ok(out) => {
                    log.finish(
                        request,
                        None,
                        true,
                        None,
                        Some(serde_json::json!({ "prs": out.prs.len() })),
                        None,
                        None,
                    );
                    Ok(out)
                }
                Err(e) => {
                    // `{:#}` renders the whole anyhow chain ("outer: ...: root cause");
                    // `to_string()` would log only the outermost context and lose the cause.
                    let msg = format!("{e:#}");
                    log.finish(request, None, false, Some(msg), None, None, None);
                    Err(map_anyhow_to_tool_error(&e))
                }
            }
        })
    }
}
// Input parameters for the `gh_add_comment_reply` tool.
//
// Plain `//` comments are used on purpose: `///` doc comments would become JSON-schema
// descriptions via `#[derive(JsonSchema)]`.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct AddCommentReplyInput {
    // Pull request number; optional. How an omitted value is resolved is decided by
    // `PrComments::add_comment_reply` (not visible here — TODO confirm).
    #[serde(default)]
    pub pr_number: Option<u64>,
    // Id of the review comment being replied to (required).
    pub comment_id: u64,
    // Reply text (required). Per the tool description, an AI identifier prefix is added
    // downstream; that prefixing is not done in this file.
    pub body: String,
}
/// Tool wrapper that exposes `PrComments::add_comment_reply` under the name
/// `gh_add_comment_reply`.
#[derive(Clone)]
pub struct AddCommentReplyTool {
    /// Shared client handle; each call clones the `Arc` into its own future.
    pr_comments: Arc<PrComments>,
}

impl AddCommentReplyTool {
    /// Creates the tool around a shared [`PrComments`] handle.
    pub fn new(shared: Arc<PrComments>) -> Self {
        Self {
            pr_comments: shared,
        }
    }
}
/// `gh_add_comment_reply`: posts a reply to an existing PR review comment.
impl Tool for AddCommentReplyTool {
    type Input = AddCommentReplyInput;
    type Output = ReviewComment;
    const NAME: &'static str = "gh_add_comment_reply";
    const DESCRIPTION: &'static str = "Reply to a PR review comment. Automatically prefixes with AI identifier to clearly mark automated responses.";

    /// Posts a reply via `PrComments::add_comment_reply` and records a structured log entry.
    ///
    /// The log records `pr_number`, `comment_id` and the body's byte length, but not the
    /// body text itself. On success it also records the id of the created reply.
    ///
    /// # Errors
    ///
    /// Any failure from `add_comment_reply` is logged (with its full context chain) and
    /// converted into a [`ToolError`] by `map_anyhow_to_tool_error`.
    fn call(
        &self,
        input: Self::Input,
        _ctx: &ToolContext,
    ) -> BoxFuture<'static, Result<Self::Output, ToolError>> {
        // Clone the `Arc` so the returned future is `'static` and does not borrow `self`.
        let pr_comments = Arc::clone(&self.pr_comments);
        Box::pin(async move {
            let log = ToolLogCtx::start(Self::NAME);
            // Only the body length is logged, presumably to keep free-form reply text out
            // of the logs.
            let request = serde_json::json!({
                "pr_number": input.pr_number,
                "comment_id": input.comment_id,
                "body_len": input.body.len(),
            });
            match pr_comments
                .add_comment_reply(input.pr_number, input.comment_id, input.body)
                .await
            {
                Ok(out) => {
                    log.finish(
                        request,
                        None,
                        true,
                        None,
                        Some(serde_json::json!({ "reply_id": out.id })),
                        None,
                        None,
                    );
                    Ok(out)
                }
                Err(e) => {
                    // `{:#}` renders the whole anyhow chain ("outer: ...: root cause");
                    // `to_string()` would log only the outermost context and lose the cause.
                    let msg = format!("{e:#}");
                    log.finish(request, None, false, Some(msg), None, None, None);
                    Err(map_anyhow_to_tool_error(&e))
                }
            }
        })
    }
}
/// Builds the [`ToolRegistry`] with the three GitHub PR tools, which all share a single
/// [`PrComments`] handle.
///
/// Registration order: `gh_get_comments`, `gh_get_prs`, `gh_add_comment_reply`.
pub fn build_registry(pr_comments: Arc<PrComments>) -> ToolRegistry {
    let get_comments = GetCommentsTool::new(Arc::clone(&pr_comments));
    let list_prs = ListPrsTool::new(Arc::clone(&pr_comments));
    // The last tool takes ownership of the original handle, so no extra clone is needed.
    let add_reply = AddCommentReplyTool::new(pr_comments);

    ToolRegistry::builder()
        .register::<GetCommentsTool, ()>(get_comments)
        .register::<ListPrsTool, ()>(list_prs)
        .register::<AddCommentReplyTool, ()>(add_reply)
        .finish()
}
/// Maps an [`anyhow::Error`] from the GitHub layer onto the closest [`ToolError`] variant.
///
/// This is a best-effort heuristic over the lowercased error text, checked in order:
/// 1. "permission", or a standalone `401` / `403` → [`ToolError::Permission`]
/// 2. "not found", or a standalone `404` → [`ToolError::NotFound`]
/// 3. "invalid" / "bad request" → [`ToolError::InvalidInput`]
/// 4. "timeout" / "timed out" / "network" → [`ToolError::External`]
/// 5. anything else → [`ToolError::Internal`]
///
/// The message carried by the returned variant is the full error chain.
fn map_anyhow_to_tool_error(e: &anyhow::Error) -> ToolError {
    // `{:#}` renders the whole context chain ("outer: ...: root cause"). A status code or
    // phrase that lives on an inner cause is therefore still visible after `.context(...)`
    // wrapping; plain `to_string()` would only show the outermost message.
    let msg = format!("{e:#}");
    let lc = msg.to_lowercase();
    let has_code = |code: &str| contains_status_code(&lc, code);

    if lc.contains("permission") || has_code("401") || has_code("403") {
        ToolError::Permission(msg)
    } else if lc.contains("not found") || has_code("404") {
        ToolError::NotFound(msg)
    } else if lc.contains("invalid") || lc.contains("bad request") {
        ToolError::InvalidInput(msg)
    } else if lc.contains("timeout") || lc.contains("timed out") || lc.contains("network") {
        ToolError::External(msg)
    } else {
        ToolError::Internal(msg)
    }
}

/// Returns `true` if `code` appears in `haystack` as a standalone number.
///
/// "Standalone" means the match is neither immediately preceded nor immediately followed
/// by another ASCII digit. This keeps ids such as `#4015` or `14030` from being read as
/// HTTP status codes.
fn contains_status_code(haystack: &str, code: &str) -> bool {
    let bytes = haystack.as_bytes();
    haystack.match_indices(code).any(|(start, matched)| {
        let before = start.checked_sub(1).and_then(|i| bytes.get(i));
        let after = bytes.get(start + matched.len());
        !matches!(before, Some(b) if b.is_ascii_digit())
            && !matches!(after, Some(b) if b.is_ascii_digit())
    })
}