use sacp::ClientToAgent;
use sacp::schema::{
AudioContent, ContentBlock, EmbeddedResourceResource, ImageContent, InitializeRequest,
RequestPermissionOutcome, RequestPermissionRequest, RequestPermissionResponse,
SessionNotification, TextContent, VERSION as PROTOCOL_VERSION,
};
use sacp::util::MatchMessage;
use sacp::{Component, Handled, MessageCx, UntypedMessage};
use std::path::PathBuf;
/// Renders a [`ContentBlock`] as a plain string.
///
/// - Text blocks produce their text unchanged.
/// - Image and audio blocks produce a bracketed placeholder built from the
///   block's MIME type, e.g. `[Image: image/png]`.
/// - Resource links and embedded resources (text or blob contents) produce
///   their URI.
pub fn content_block_to_string(block: &ContentBlock) -> String {
    match block {
        ContentBlock::Text(TextContent { text, .. }) => text.to_owned(),
        ContentBlock::Image(ImageContent { mime_type, .. }) => format!("[Image: {mime_type}]"),
        ContentBlock::Audio(AudioContent { mime_type, .. }) => format!("[Audio: {mime_type}]"),
        ContentBlock::ResourceLink(resource_link) => resource_link.uri.to_owned(),
        ContentBlock::Resource(embedded) => {
            // Both embedded-resource flavours carry a URI; pick it out, then copy once.
            let uri = match &embedded.resource {
                EmbeddedResourceResource::TextResourceContents(contents) => &contents.uri,
                EmbeddedResourceResource::BlobResourceContents(contents) => &contents.uri,
            };
            uri.to_owned()
        }
    }
}
/// Sends a single prompt to `component` and streams the agent's reply to `callback`.
///
/// The function connects to `component` as a client and sends an
/// `InitializeRequest`. It then builds a session for the path `"."`, sends
/// `prompt_text`, and reads session updates until the agent reports a stop
/// reason:
///
/// - For every `AgentMessageChunk` notification, the chunk's content is passed
///   to `callback`.
/// - Every `RequestPermissionRequest` is answered by selecting the first option
///   whose kind is `AllowOnce` or `AllowAlways`. If there is no such option,
///   the answer is `Cancelled`.
/// - All other session messages are ignored.
///
/// Any stop reason (`EndTurn`, `MaxTokens`, `MaxTurnRequests`, `Refusal`,
/// `Cancelled`) ends the turn and makes the function return `Ok(())`.
/// Reasons other than `EndTurn` are logged as warnings.
///
/// # Errors
///
/// Returns any error produced while connecting, initializing, creating the
/// session, sending the prompt, reading updates, or responding to a
/// permission request.
pub async fn prompt_with_callback(
    component: impl Component,
    prompt_text: impl ToString,
    mut callback: impl AsyncFnMut(ContentBlock) + Send,
) -> Result<(), sacp::Error> {
    let prompt_text = prompt_text.to_string();
    ClientToAgent::builder()
        .on_receive_message(
            async |message: MessageCx<UntypedMessage, UntypedMessage>, _cx| {
                // Trace everything, but leave the message unhandled so the
                // session-level handling below still sees it.
                tracing::trace!("received: {:?}", message.message());
                Ok(Handled::No {
                    message,
                    retry: false,
                })
            },
        )
        .connect_to(component)?
        .with_client(|cx: sacp::JrConnectionCx<ClientToAgent>| async move {
            let _init_response = cx
                .send_request(InitializeRequest {
                    protocol_version: PROTOCOL_VERSION,
                    client_capabilities: Default::default(),
                    client_info: None,
                    meta: None,
                })
                .block_task()
                .await?;
            let mut session = cx.build_session(PathBuf::from(".")).send_request().await?;
            session.send_prompt(prompt_text)?;
            loop {
                let update = session.read_update().await?;
                match update {
                    sacp::SessionMessage::SessionMessage(message) => {
                        MatchMessage::new(message)
                            .if_notification(async |notification: SessionNotification| {
                                tracing::debug!(
                                    ?notification,
                                    "yopo: received SessionNotification"
                                );
                                // Only agent message chunks are forwarded; other
                                // update kinds are logged above and dropped.
                                if let sacp::schema::SessionUpdate::AgentMessageChunk(
                                    content_chunk,
                                ) = notification.update
                                {
                                    callback(content_chunk.content).await;
                                }
                                Ok(())
                            })
                            .await
                            .if_request(async |request: RequestPermissionRequest, request_cx| {
                                // Auto-approve: pick the first "allow" option, or
                                // cancel when the agent offered only reject options.
                                let outcome = request
                                    .options
                                    .iter()
                                    .find(|option| match option.kind {
                                        sacp::schema::PermissionOptionKind::AllowOnce
                                        | sacp::schema::PermissionOptionKind::AllowAlways => true,
                                        sacp::schema::PermissionOptionKind::RejectOnce
                                        | sacp::schema::PermissionOptionKind::RejectAlways => false,
                                    })
                                    .map(|option| RequestPermissionOutcome::Selected {
                                        option_id: option.id.clone(),
                                    })
                                    .unwrap_or(RequestPermissionOutcome::Cancelled);
                                request_cx.respond(RequestPermissionResponse {
                                    outcome,
                                    meta: None,
                                })?;
                                Ok(())
                            })
                            .await
                            // NOTE(review): other requests are dropped without a
                            // response; an agent waiting on one may stall. Confirm
                            // whether they should be rejected explicitly.
                            .otherwise(async |_msg| Ok(()))
                            .await?;
                    }
                    sacp::SessionMessage::StopReason(stop_reason) => {
                        // Every stop reason means the agent has finished this
                        // prompt turn. These are expected outcomes, not bugs, so
                        // log the unusual ones and stop reading updates.
                        match stop_reason {
                            sacp::schema::StopReason::EndTurn => {}
                            sacp::schema::StopReason::MaxTokens => {
                                tracing::warn!("yopo: turn ended: max tokens reached");
                            }
                            sacp::schema::StopReason::MaxTurnRequests => {
                                tracing::warn!("yopo: turn ended: max turn requests reached");
                            }
                            sacp::schema::StopReason::Refusal => {
                                tracing::warn!("yopo: turn ended: agent refused");
                            }
                            sacp::schema::StopReason::Cancelled => {
                                tracing::warn!("yopo: turn ended: cancelled");
                            }
                        }
                        break;
                    }
                    _ => {
                        // NOTE(review): assumes any other session message is
                        // informational; ignore it rather than panicking and keep
                        // waiting for a stop reason.
                        tracing::warn!("yopo: ignoring unrecognized session message");
                    }
                }
            }
            Ok(())
        })
        .await?;
    Ok(())
}
/// Sends `prompt_text` to `component` and returns the agent's reply as one string.
///
/// Each streamed content block is converted with [`content_block_to_string`]
/// and appended in arrival order, with no separators inserted.
///
/// # Errors
///
/// Propagates any error returned by [`prompt_with_callback`].
pub async fn prompt(
    component: impl Component,
    prompt_text: impl ToString,
) -> Result<String, sacp::Error> {
    let mut reply = String::new();
    prompt_with_callback(component, prompt_text, async |chunk| {
        reply.push_str(&content_block_to_string(&chunk));
    })
    .await?;
    Ok(reply)
}