use std::cell::Cell;
use agent_client_protocol::{self as acp, Client as _};
use serde_json::json;
use tokio::sync::{mpsc, oneshot};
use tokio_util::compat::{TokioAsyncReadCompatExt as _, TokioAsyncWriteCompatExt as _};
/// Example agent state shared by all request handlers.
struct ExampleAgent {
    /// Sending half of the channel used to hand session notifications to whoever
    /// owns the receiving half. Each notification travels with a one-shot sender
    /// that the receiver can fire to signal the notification has been handled.
    session_update_tx: mpsc::UnboundedSender<(acp::SessionNotification, oneshot::Sender<()>)>,
    /// Numeric id that will be handed out to the next created session.
    /// A `Cell` is used so the counter can be advanced through `&self`.
    next_session_id: Cell<u64>,
}
impl ExampleAgent {
fn new(
session_update_tx: mpsc::UnboundedSender<(acp::SessionNotification, oneshot::Sender<()>)>,
) -> Self {
Self {
session_update_tx,
next_session_id: Cell::new(0),
}
}
}
#[async_trait::async_trait(?Send)]
impl acp::Agent for ExampleAgent {
async fn initialize(
&self,
arguments: acp::InitializeRequest,
) -> Result<acp::InitializeResponse, acp::Error> {
log::info!("Received initialize request {arguments:?}");
Ok(acp::InitializeResponse::new(acp::ProtocolVersion::V1)
.agent_info(acp::Implementation::new("example-agent", "0.1.0").title("Example Agent")))
}
async fn authenticate(
&self,
arguments: acp::AuthenticateRequest,
) -> Result<acp::AuthenticateResponse, acp::Error> {
log::info!("Received authenticate request {arguments:?}");
Ok(acp::AuthenticateResponse::default())
}
async fn new_session(
&self,
arguments: acp::NewSessionRequest,
) -> Result<acp::NewSessionResponse, acp::Error> {
log::info!("Received new session request {arguments:?}");
let session_id = self.next_session_id.get();
self.next_session_id.set(session_id + 1);
Ok(acp::NewSessionResponse::new(session_id.to_string()))
}
async fn load_session(
&self,
arguments: acp::LoadSessionRequest,
) -> Result<acp::LoadSessionResponse, acp::Error> {
log::info!("Received load session request {arguments:?}");
Ok(acp::LoadSessionResponse::new())
}
async fn prompt(
&self,
arguments: acp::PromptRequest,
) -> Result<acp::PromptResponse, acp::Error> {
log::info!("Received prompt request {arguments:?}");
for content in ["Client sent: ".into()].into_iter().chain(arguments.prompt) {
let (tx, rx) = oneshot::channel();
self.session_update_tx
.send((
acp::SessionNotification::new(
arguments.session_id.clone(),
acp::SessionUpdate::AgentMessageChunk(acp::ContentChunk::new(content)),
),
tx,
))
.map_err(|_| acp::Error::internal_error())?;
rx.await.map_err(|_| acp::Error::internal_error())?;
}
Ok(acp::PromptResponse::new(acp::StopReason::EndTurn))
}
async fn cancel(&self, args: acp::CancelNotification) -> Result<(), acp::Error> {
log::info!("Received cancel request {args:?}");
Ok(())
}
async fn set_session_mode(
&self,
args: acp::SetSessionModeRequest,
) -> Result<acp::SetSessionModeResponse, acp::Error> {
log::info!("Received set session mode request {args:?}");
Ok(acp::SetSessionModeResponse::default())
}
#[cfg(feature = "unstable_session_model")]
async fn set_session_model(
&self,
args: acp::SetSessionModelRequest,
) -> Result<acp::SetSessionModelResponse, acp::Error> {
log::info!("Received select model request {args:?}");
Ok(acp::SetSessionModelResponse::default())
}
async fn set_session_config_option(
&self,
args: acp::SetSessionConfigOptionRequest,
) -> Result<acp::SetSessionConfigOptionResponse, acp::Error> {
log::info!("Received set session config option request {args:?}");
let value = args
.value
.as_value_id()
.ok_or(acp::Error::invalid_params())?
.clone();
let option = acp::SessionConfigOption::select(
args.config_id,
"Example Option",
value,
vec![
acp::SessionConfigSelectOption::new("option1", "Option 1"),
acp::SessionConfigSelectOption::new("option2", "Option 2"),
],
);
Ok(acp::SetSessionConfigOptionResponse::new(vec![option]))
}
async fn ext_method(&self, args: acp::ExtRequest) -> Result<acp::ExtResponse, acp::Error> {
log::info!(
"Received extension method call: method={}, params={:?}",
args.method,
args.params
);
Ok(acp::ExtResponse::new(
serde_json::value::to_raw_value(&json!({"example": "response"}))?.into(),
))
}
async fn ext_notification(&self, args: acp::ExtNotification) -> Result<(), acp::Error> {
log::info!(
"Received extension notification: method={}, params={:?}",
args.method,
args.params
);
Ok(())
}
}
#[tokio::main(flavor = "current_thread")]
async fn main() -> acp::Result<()> {
env_logger::init();
let outgoing = tokio::io::stdout().compat_write();
let incoming = tokio::io::stdin().compat();
let local_set = tokio::task::LocalSet::new();
local_set
.run_until(async move {
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
let (conn, handle_io) =
acp::AgentSideConnection::new(ExampleAgent::new(tx), outgoing, incoming, |fut| {
tokio::task::spawn_local(fut);
});
tokio::task::spawn_local(async move {
while let Some((session_notification, tx)) = rx.recv().await {
let result = conn.session_notification(session_notification).await;
if let Err(e) = result {
log::error!("{e}");
break;
}
tx.send(()).ok();
}
});
handle_io.await
})
.await
}