use std::fmt::{Debug, Display};
use std::marker::PhantomData;
use std::sync::Arc;
use agent_client_protocol::mcp_server::{McpConnectionTo, McpServer, McpServerBuilder};
use agent_client_protocol::role::{HasPeer, Role};
use agent_client_protocol::schema::{
PermissionOptionKind, RequestPermissionOutcome, RequestPermissionRequest,
RequestPermissionResponse, SelectedPermissionOutcome, SessionNotification, StopReason,
};
use agent_client_protocol::util::MatchDispatch;
use agent_client_protocol::{Agent, BoxFuture, ConnectionTo, NullRun, RunWithConnectionTo};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use tracing::{debug, info, trace, warn};
use crate::Error;
/// Callbacks for watching a think block while it runs.
///
/// Every hook has an empty default body, so an implementor only overrides the
/// events it is interested in. [`ThinkBuilder`] stores observers as
/// `Arc<dyn ThinkObserver>`.
pub trait ThinkObserver: Send + Sync {
    /// Called once with the fully assembled prompt, before the agent session is opened.
    fn on_prompt(&self, _prompt: &str) {}

    /// Called for each session notification the agent emits during the block.
    fn on_notification(&self, _notification: &SessionNotification) {}

    /// Called when the agent asks for permission, before the request is answered.
    fn on_permission_request(&self, _request: &RequestPermissionRequest) {}

    /// Called with the agent's stop reason when the session turn ends.
    fn on_stop(&self, _reason: &StopReason) {}
}
/// Builder for a single "think" block: a prompt sent to an agent together with
/// an MCP server exposing the tools the agent may call.
///
/// The prompt is accumulated as a list of segments. Awaiting the builder runs
/// the block and yields the `Output` value the agent hands back.
pub struct ThinkBuilder<
    'bound,
    Output,
    R: Role + HasPeer<Agent> = Agent,
    Run: RunWithConnectionTo<R> = NullRun,
> {
    /// Connection used to open the agent session.
    cx: ConnectionTo<R>,
    /// Prompt pieces, rendered in order when the block runs.
    segments: Vec<Segment>,
    /// MCP server being assembled; each registered tool changes `Run`.
    server: McpServerBuilder<R, Run>,
    /// When `true`, segments are concatenated verbatim with no automatic spacing.
    explicit_spacing: bool,
    /// Optional observer notified of prompt, notifications, permissions and stop.
    observer: Option<Arc<dyn ThinkObserver>>,
    /// `'bound` and `Output` appear only here. The `fn(..)` form avoids storing
    /// either and keeps auto traits independent of `Output`.
    phantom: PhantomData<fn(&'bound Run) -> Output>,
}
/// One piece of a think-block prompt.
enum Segment {
    /// Literal prompt text, inserted as-is (subject to automatic spacing).
    Text(String),
    /// Name of a registered tool, rendered as `<mcp_tool>NAME</mcp_tool>`.
    ToolReference(String),
}
impl<'bound, Output, R: Role> ThinkBuilder<'bound, Output, R, NullRun>
where
    R: HasPeer<Agent>,
    Output: Send + JsonSchema + DeserializeOwned + 'static,
{
    /// Creates a builder bound to `cx`, optionally reporting to `observer`.
    ///
    /// The MCP server starts out named `patchwork` with no tools. The prompt is
    /// pre-seeded with standing instructions that tell the agent to work
    /// autonomously and finish by calling `return_result`.
    pub(crate) fn new(cx: ConnectionTo<R>, observer: Option<Arc<dyn ThinkObserver>>) -> Self {
        // Each entry becomes its own `textln` segment, in this order.
        const PREAMBLE: &[&str] = &[
            "Please complete the following task to the best of your ability,",
            "No further instructions will be given,",
            "so do your best to interpret the instructions without further feedback from the user,",
            "making use of the tools you have available.",
            "",
            "IMPORTANT: When complete, invoke the `return_result` tool with the requested result.",
            "",
        ];

        let server = McpServer::builder("patchwork".to_string())
            .instructions("You have access to tools. Call return_result when done.");

        let mut builder = Self {
            cx,
            segments: Vec::new(),
            server,
            explicit_spacing: false,
            observer,
            phantom: PhantomData,
        };
        for line in PREAMBLE {
            builder = builder.textln(line);
        }
        builder
    }
}
impl<'bound, Output, R: Role, Run: RunWithConnectionTo<R>> ThinkBuilder<'bound, Output, R, Run>
where
    R: HasPeer<Agent>,
    Output: Send + JsonSchema + DeserializeOwned + 'static,
{
    /// Appends `text` to the prompt as a new segment.
    ///
    /// Unless [`explicit_spacing`](Self::explicit_spacing) was called, a single
    /// space is inserted between this segment and the previous one when needed.
    pub fn text(mut self, text: &str) -> Self {
        self.segments.push(Segment::Text(text.to_string()));
        self
    }

    /// Appends `text` followed by a newline.
    pub fn textln(mut self, text: &str) -> Self {
        self.segments.push(Segment::Text(format!("{text}\n")));
        self
    }

    /// Appends the [`Display`] rendering of `value`.
    pub fn display(mut self, value: &impl Display) -> Self {
        self.segments.push(Segment::Text(value.to_string()));
        self
    }

    /// Appends the [`Debug`] rendering of `value`.
    pub fn debug(mut self, value: &impl Debug) -> Self {
        self.segments.push(Segment::Text(format!("{value:?}")));
        self
    }

    /// Turns off automatic spacing: segments are then concatenated verbatim.
    pub fn explicit_spacing(mut self) -> Self {
        self.explicit_spacing = true;
        self
    }

    /// Renders all segments into the final prompt string.
    ///
    /// Tool references render as `<mcp_tool>NAME</mcp_tool>`. Without explicit
    /// spacing, each segment (text or tool reference alike) is separated from the
    /// previous one by a single space. The space is omitted when either side is
    /// empty, when the boundary already has whitespace, when the previous segment
    /// ends with an opening bracket, or when the next one starts with punctuation
    /// or a closing bracket.
    fn build_prompt(&self) -> String {
        /// Whether a space belongs between already-rendered `prev` and upcoming `next`.
        fn needs_separator(prev: &str, next: &str) -> bool {
            let (Some(last), Some(first)) = (prev.chars().next_back(), next.chars().next())
            else {
                // Nothing to separate from, or nothing being added.
                return false;
            };
            let glued_after = last.is_whitespace() || matches!(last, '(' | '[' | '{');
            let glued_before = first.is_whitespace()
                || matches!(first, '.' | ',' | ':' | ';' | '!' | '?' | ')' | ']' | '}');
            !glued_after && !glued_before
        }

        let mut result = String::new();
        for segment in &self.segments {
            // Owns the rendered markup for tool references for this iteration.
            let tool_markup;
            let text = match segment {
                Segment::Text(text) => text.as_str(),
                Segment::ToolReference(name) => {
                    tool_markup = format!("<mcp_tool>{name}</mcp_tool>");
                    tool_markup.as_str()
                }
            };
            if !self.explicit_spacing && needs_separator(&result, text) {
                result.push(' ');
            }
            result.push_str(text);
        }
        result
    }

    /// Registers tool `name` on the block's MCP server and mentions it in the
    /// prompt at the current position.
    ///
    /// `func` handles each invocation. `tool_future_hack` adapts it into a boxed
    /// future; presumably this is `agent_client_protocol::tool_fn_mut!()`, the same
    /// macro the block uses to register `return_result`. Returns a builder whose
    /// `Run` type includes the new tool.
    pub fn tool<I, O, F, H>(
        mut self,
        name: &str,
        description: &str,
        func: F,
        tool_future_hack: H,
    ) -> ThinkBuilder<'bound, Output, R, impl RunWithConnectionTo<R>>
    where
        I: JsonSchema + DeserializeOwned + Send + 'static,
        O: JsonSchema + Serialize + Send + 'static,
        F: AsyncFnMut(I, McpConnectionTo<R>) -> Result<O, agent_client_protocol::Error> + Send,
        H: for<'a> Fn(
                &'a mut F,
                I,
                McpConnectionTo<R>,
            ) -> BoxFuture<'a, Result<O, agent_client_protocol::Error>>
            + Send
            + 'static,
    {
        debug!(tool_name = name, "registering tool");
        self.segments.push(Segment::ToolReference(name.to_string()));
        ThinkBuilder {
            cx: self.cx,
            segments: self.segments,
            server: self
                .server
                .tool_fn_mut(name, description, func, tool_future_hack),
            explicit_spacing: self.explicit_spacing,
            observer: self.observer,
            phantom: PhantomData,
        }
    }

    /// Registers tool `name` on the block's MCP server without mentioning it
    /// in the prompt.
    ///
    /// Takes the same arguments as [`tool`](Self::tool).
    pub fn define_tool<I, O, F, H>(
        self,
        name: &str,
        description: &str,
        func: F,
        tool_future_hack: H,
    ) -> ThinkBuilder<'bound, Output, R, impl RunWithConnectionTo<R>>
    where
        I: JsonSchema + DeserializeOwned + Send + 'static,
        O: JsonSchema + Serialize + Send + 'static,
        F: AsyncFnMut(I, McpConnectionTo<R>) -> Result<O, agent_client_protocol::Error> + Send,
        H: for<'a> Fn(
                &'a mut F,
                I,
                McpConnectionTo<R>,
            ) -> BoxFuture<'a, Result<O, agent_client_protocol::Error>>
            + Send
            + 'static,
    {
        debug!(tool_name = name, "defining tool (hidden from prompt)");
        ThinkBuilder {
            cx: self.cx,
            segments: self.segments,
            server: self
                .server
                .tool_fn_mut(name, description, func, tool_future_hack),
            explicit_spacing: self.explicit_spacing,
            observer: self.observer,
            phantom: PhantomData,
        }
    }
}
impl<'bound, Output, R: Role, Run: RunWithConnectionTo<R>> IntoFuture for ThinkBuilder<'bound, Output, R, Run>
where
R: HasPeer<Agent>,
Output: Send + JsonSchema + DeserializeOwned + 'static,
Run: Send,
{
type Output = Result<Output, Error>;
type IntoFuture = BoxFuture<'bound, Result<Output, Error>>;
fn into_future(self) -> Self::IntoFuture {
Box::pin(async move {
let prompt = self.build_prompt();
let cx = self.cx;
let observer = self.observer;
let mut output: Option<Output> = None;
let server = self.server.tool_fn_mut(
"return_result",
"Return the final result. Call this when you have completed the task.",
async |input: ReturnResultInput<Output>, _cx| {
debug!("return_result tool invoked");
output = Some(input.result);
Ok(ReturnResultOutput { success: true })
},
agent_client_protocol::tool_fn_mut!(),
);
if let Some(observer) = &observer {
observer.on_prompt(&prompt);
}
info!(prompt_len = prompt.len(), "executing think block");
trace!(prompt = %prompt, "full prompt");
let cwd = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("/"));
cx.build_session(&cwd)
.with_mcp_server(server.build())?
.block_task()
.run_until(async |mut session| {
session.send_prompt(&prompt)?;
tracing::info!(?prompt, "sending prompt");
loop {
let update = session.read_update().await?;
trace!(?update, "received session update");
match update {
agent_client_protocol::SessionMessage::StopReason(reason) => {
debug!(?reason, "session stopped");
if let Some(observer) = &observer {
observer.on_stop(&reason);
}
break;
}
agent_client_protocol::SessionMessage::SessionMessage(dispatch) => {
MatchDispatch::new(dispatch)
.if_notification(async |notification: SessionNotification| {
tracing::debug!(?notification, "received session notification");
if let Some(observer) = &observer {
observer.on_notification(¬ification);
}
Ok(())
})
.await
.if_request(
async |request: RequestPermissionRequest, responder| {
tracing::debug!(
?request,
"received tool use permission request"
);
if let Some(observer) = &observer {
observer.on_permission_request(&request);
}
let option =
request.options.iter().find(|o| match o.kind {
PermissionOptionKind::AllowOnce
| PermissionOptionKind::AllowAlways => true,
PermissionOptionKind::RejectOnce
| PermissionOptionKind::RejectAlways => false,
_ => false,
});
let outcome = option
.map(|o| {
RequestPermissionOutcome::Selected(
SelectedPermissionOutcome::new(
o.option_id.clone(),
),
)
})
.unwrap_or(RequestPermissionOutcome::Cancelled);
responder.respond(RequestPermissionResponse::new(outcome))
},
)
.await
.otherwise_ignore()?
}
_ => continue,
}
}
Ok(())
})
.await?;
if output.is_some() {
info!("think block completed successfully");
} else {
warn!("think block completed but no result was returned");
}
output.ok_or(Error::NoResult)
})
}
}
// Arguments of the built-in `return_result` tool. The agent passes the task's
// final answer as `result`, which is deserialized into the block's output type.
//
// Plain `//` comments are deliberate: the `JsonSchema` derive turns `///` doc
// comments into schema descriptions, which would alter the tool schema.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct ReturnResultInput<T> {
    // The value handed back to whoever awaits the think block.
    result: T,
}
// Acknowledgement sent back to the agent after it calls `return_result`.
// `//` rather than `///` so the derived JSON schema stays unchanged.
#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema)]
struct ReturnResultOutput {
    // Set to `true` by the `return_result` handler.
    success: bool,
}