1use anyhow::Result;
2use aster::agents::extension::{Envs, PlatformExtensionContext, PLATFORM_EXTENSIONS};
3use aster::agents::{Agent, ExtensionConfig, SessionConfig};
4use aster::config::{get_all_extensions, Config};
5use aster::conversation::message::{ActionRequiredData, Message, MessageContent};
6use aster::conversation::Conversation;
7use aster::mcp_utils::ToolResult;
8use aster::permission::permission_confirmation::PrincipalType;
9use aster::permission::{Permission, PermissionConfirmation};
10use aster::providers::create;
11use aster::session::session_manager::SessionType;
12use aster::session::SessionManager;
13use rmcp::model::{CallToolResult, RawContent, ResourceContents, Role};
14use sacp::schema::{
15 AgentCapabilities, AuthenticateRequest, AuthenticateResponse, BlobResourceContents,
16 CancelNotification, Content, ContentBlock, ContentChunk, EmbeddedResource,
17 EmbeddedResourceResource, ImageContent, InitializeRequest, InitializeResponse,
18 LoadSessionRequest, LoadSessionResponse, McpCapabilities, McpServer, NewSessionRequest,
19 NewSessionResponse, PermissionOption, PermissionOptionId, PermissionOptionKind,
20 PromptCapabilities, PromptRequest, PromptResponse, RequestPermissionOutcome,
21 RequestPermissionRequest, ResourceLink, SessionId, SessionNotification, SessionUpdate,
22 StopReason, TextContent, TextResourceContents, ToolCall, ToolCallContent, ToolCallId,
23 ToolCallLocation, ToolCallStatus, ToolCallUpdate, ToolCallUpdateFields, ToolKind,
24};
25use sacp::{AgentToClient, ByteStreams, Handled, JrConnectionCx, JrMessageHandler, MessageCx};
26use std::collections::{HashMap, HashSet};
27use std::fs;
28use std::sync::Arc;
29use tokio::sync::Mutex;
30use tokio::task::JoinSet;
31use tokio_util::compat::{TokioAsyncReadCompatExt as _, TokioAsyncWriteCompatExt as _};
32use tokio_util::sync::CancellationToken;
33use tracing::{debug, error, info, warn};
34use url::Url;
35
36struct AsterAcpSession {
37 messages: Conversation,
38 tool_requests: HashMap<String, aster::conversation::message::ToolRequest>,
39 cancel_token: Option<CancellationToken>,
40}
41
42struct AsterAcpAgent {
43 sessions: Arc<Mutex<HashMap<String, AsterAcpSession>>>,
44 agent: Arc<Agent>,
45}
46
47fn mcp_server_to_extension_config(mcp_server: McpServer) -> Result<ExtensionConfig, String> {
48 match mcp_server {
49 McpServer::Stdio(stdio) => Ok(ExtensionConfig::Stdio {
50 name: stdio.name,
51 description: String::new(),
52 cmd: stdio.command.to_string_lossy().to_string(),
53 args: stdio.args,
54 envs: Envs::new(stdio.env.into_iter().map(|e| (e.name, e.value)).collect()),
55 env_keys: vec![],
56 timeout: None,
57 bundled: Some(false),
58 available_tools: vec![],
59 }),
60 McpServer::Http(http) => Ok(ExtensionConfig::StreamableHttp {
61 name: http.name,
62 description: String::new(),
63 uri: http.url,
64 envs: Envs::default(),
65 env_keys: vec![],
66 headers: http
67 .headers
68 .into_iter()
69 .map(|h| (h.name, h.value))
70 .collect(),
71 timeout: None,
72 bundled: Some(false),
73 available_tools: vec![],
74 }),
75 McpServer::Sse(_) => Err("SSE is unsupported, migrate to streamable_http".to_string()),
76 _ => Err("Unknown MCP server type".to_string()),
77 }
78}
79
80fn create_tool_location(path: &str, line: Option<u32>) -> ToolCallLocation {
81 let mut loc = ToolCallLocation::new(path);
82 if let Some(l) = line {
83 loc = loc.line(l);
84 }
85 loc
86}
87
88fn extract_tool_locations(
89 tool_request: &aster::conversation::message::ToolRequest,
90 tool_response: &aster::conversation::message::ToolResponse,
91) -> Vec<ToolCallLocation> {
92 let mut locations = Vec::new();
93
94 if let Ok(tool_call) = &tool_request.tool_call {
96 if tool_call.name != "developer__text_editor" {
98 return locations;
99 }
100
101 let path_str = tool_call
103 .arguments
104 .as_ref()
105 .and_then(|args| args.get("path"))
106 .and_then(|p| p.as_str());
107
108 if let Some(path_str) = path_str {
109 let command = tool_call
111 .arguments
112 .as_ref()
113 .and_then(|args| args.get("command"))
114 .and_then(|c| c.as_str());
115
116 if let Ok(result) = &tool_response.tool_result {
118 for content in &result.content {
119 if let RawContent::Text(text_content) = &content.raw {
120 let text = &text_content.text;
121
122 match command {
124 Some("view") => {
125 let line = extract_view_line_range(text)
127 .map(|range| range.0 as u32)
128 .or(Some(1));
129 locations.push(create_tool_location(path_str, line));
130 }
131 Some("str_replace") | Some("insert") => {
132 let line = extract_first_line_number(text)
134 .map(|l| l as u32)
135 .or(Some(1));
136 locations.push(create_tool_location(path_str, line));
137 }
138 Some("write") => {
139 locations.push(create_tool_location(path_str, Some(1)));
141 }
142 _ => {
143 locations.push(create_tool_location(path_str, Some(1)));
145 }
146 }
147 break; }
149 }
150 }
151
152 if locations.is_empty() {
154 locations.push(create_tool_location(path_str, Some(1)));
155 }
156 }
157 }
158
159 locations
160}
161
162fn extract_view_line_range(text: &str) -> Option<(usize, usize)> {
163 let re = regex::Regex::new(r"\(lines (\d+)-(\d+|end)\)").ok()?;
165 if let Some(caps) = re.captures(text) {
166 let start = caps.get(1)?.as_str().parse::<usize>().ok()?;
167 let end = if caps.get(2)?.as_str() == "end" {
168 start } else {
170 caps.get(2)?.as_str().parse::<usize>().ok()?
171 };
172 return Some((start, end));
173 }
174 None
175}
176
177fn extract_first_line_number(text: &str) -> Option<usize> {
178 let re = regex::Regex::new(r"```[^\n]*\n(\d+):").ok()?;
180 if let Some(caps) = re.captures(text) {
181 return caps.get(1)?.as_str().parse::<usize>().ok();
182 }
183 None
184}
185
186fn read_resource_link(link: ResourceLink) -> Option<String> {
187 let url = Url::parse(&link.uri).ok()?;
188 if url.scheme() == "file" {
189 let path = url.to_file_path().ok()?;
190 let contents = fs::read_to_string(&path).ok()?;
191
192 Some(format!(
193 "\n\n# {}\n```\n{}\n```",
194 path.to_string_lossy(),
195 contents
196 ))
197 } else {
198 None
199 }
200}
201
/// Turns an internal tool identifier into a human-readable title.
///
/// `extension__tool_name` becomes `Extension: Tool Name`; an identifier without
/// a `__` separator is simply title-cased (`simple_tool` -> `Simple Tool`).
/// Only the first `__` is treated as the extension/tool separator.
fn format_tool_name(tool_name: &str) -> String {
    /// Replaces underscores with spaces and upper-cases the first character of
    /// every whitespace-separated word.
    fn title_case(raw: &str) -> String {
        raw.replace('_', " ")
            .split_whitespace()
            .map(|word| {
                let mut chars = word.chars();
                match chars.next() {
                    Some(first) => first.to_uppercase().chain(chars).collect::<String>(),
                    None => String::new(),
                }
            })
            .collect::<Vec<_>>()
            .join(" ")
    }

    match tool_name.split_once("__") {
        Some((extension, tool)) => format!("{}: {}", title_case(extension), title_case(tool)),
        None => title_case(tool_name),
    }
}
242
243async fn add_builtins(agent: &Agent, builtins: Vec<String>) {
244 for builtin in builtins {
245 let config = if PLATFORM_EXTENSIONS.contains_key(builtin.as_str()) {
246 ExtensionConfig::Platform {
247 name: builtin.clone(),
248 bundled: None,
249 description: builtin.clone(),
250 available_tools: Vec::new(),
251 }
252 } else {
253 ExtensionConfig::Builtin {
254 name: builtin.clone(),
255 display_name: None,
256 timeout: None,
257 bundled: None,
258 description: builtin.clone(),
259 available_tools: Vec::new(),
260 }
261 };
262 match agent.add_extension(config).await {
263 Ok(_) => info!(extension = %builtin, "builtin extension loaded"),
264 Err(e) => warn!(extension = %builtin, error = %e, "builtin extension load failed"),
265 }
266 }
267}
268
269impl AsterAcpAgent {
270 async fn new(builtins: Vec<String>) -> Result<Self> {
271 let config = Config::global();
272
273 let provider_name: String = config
274 .get_aster_provider()
275 .map_err(|e| anyhow::anyhow!("No provider configured: {}", e))?;
276
277 let model_name: String = config
278 .get_aster_model()
279 .map_err(|e| anyhow::anyhow!("No model configured: {}", e))?;
280
281 let model_config = aster::model::ModelConfig {
282 model_name: model_name.clone(),
283 context_limit: None,
284 temperature: None,
285 max_tokens: None,
286 toolshim: false,
287 toolshim_model: None,
288 fast_model: None,
289 };
290 let provider = create(&provider_name, model_config).await?;
291
292 let session = SessionManager::create_session(
293 std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
294 "ACP Session".to_string(),
295 SessionType::Hidden,
296 )
297 .await?;
298
299 let agent = Agent::new();
300 agent.update_provider(provider.clone(), &session.id).await?;
301
302 let extensions_to_run: Vec<_> = get_all_extensions()
303 .into_iter()
304 .filter(|ext| ext.enabled)
305 .map(|ext| ext.config)
306 .collect();
307
308 let agent_ptr = Arc::new(agent);
309
310 agent_ptr
312 .extension_manager
313 .set_context(PlatformExtensionContext {
314 session_id: Some(session.id.clone()),
315 extension_manager: Some(Arc::downgrade(&agent_ptr.extension_manager)),
316 })
317 .await;
318
319 let mut set = JoinSet::new();
320 let mut waiting_on = HashSet::new();
321
322 for extension in extensions_to_run {
323 waiting_on.insert(extension.name());
324 let agent_ptr_clone = agent_ptr.clone();
325 set.spawn(async move {
326 (
327 extension.name(),
328 agent_ptr_clone.add_extension(extension.clone()).await,
329 )
330 });
331 }
332
333 while let Some(result) = set.join_next().await {
334 match result {
335 Ok((name, Ok(_))) => {
336 waiting_on.remove(&name);
337 info!(extension = %name, "extension loaded");
338 }
339 Ok((name, Err(e))) => {
340 warn!(extension = %name, error = %e, "extension load failed");
341 waiting_on.remove(&name);
342 }
343 Err(e) => {
344 error!(error = %e, "extension task error");
345 }
346 }
347 }
348
349 add_builtins(&agent_ptr, builtins).await;
350
351 Ok(Self {
352 sessions: Arc::new(Mutex::new(HashMap::new())),
353 agent: agent_ptr,
354 })
355 }
356
357 fn convert_acp_prompt_to_message(&self, prompt: Vec<ContentBlock>) -> Message {
358 let mut user_message = Message::user();
359
360 for block in prompt {
362 match block {
363 ContentBlock::Text(text) => {
364 user_message = user_message.with_text(&text.text);
365 }
366 ContentBlock::Image(image) => {
367 user_message = user_message.with_image(&image.data, &image.mime_type);
370 }
371 ContentBlock::Resource(resource) => {
372 match &resource.resource {
374 EmbeddedResourceResource::TextResourceContents(text_resource) => {
375 let header = format!("--- Resource: {} ---\n", text_resource.uri);
376 let content = format!("{}{}\n---\n", header, text_resource.text);
377 user_message = user_message.with_text(&content);
378 }
379 _ => {
380 }
382 }
383 }
384 ContentBlock::ResourceLink(link) => {
385 if let Some(text) = read_resource_link(link) {
386 user_message = user_message.with_text(text)
387 }
388 }
389 ContentBlock::Audio(..) => (),
390 _ => (), }
392 }
393
394 user_message
395 }
396
397 async fn handle_message_content(
398 &self,
399 content_item: &MessageContent,
400 session_id: &SessionId,
401 session: &mut AsterAcpSession,
402 cx: &JrConnectionCx<AgentToClient>,
403 ) -> Result<(), sacp::Error> {
404 match content_item {
405 MessageContent::Text(text) => {
406 cx.send_notification(SessionNotification::new(
408 session_id.clone(),
409 SessionUpdate::AgentMessageChunk(ContentChunk::new(ContentBlock::Text(
410 TextContent::new(&text.text),
411 ))),
412 ))?;
413 }
414 MessageContent::ToolRequest(tool_request) => {
415 self.handle_tool_request(tool_request, session_id, session, cx)
416 .await?;
417 }
418 MessageContent::ToolResponse(tool_response) => {
419 self.handle_tool_response(tool_response, session_id, session, cx)
420 .await?;
421 }
422 MessageContent::Thinking(thinking) => {
423 cx.send_notification(SessionNotification::new(
425 session_id.clone(),
426 SessionUpdate::AgentThoughtChunk(ContentChunk::new(ContentBlock::Text(
427 TextContent::new(&thinking.thinking),
428 ))),
429 ))?;
430 }
431 MessageContent::ActionRequired(action_required) => {
432 if let ActionRequiredData::ToolConfirmation {
433 id,
434 tool_name,
435 arguments,
436 prompt,
437 } = &action_required.data
438 {
439 self.handle_tool_permission_request(
440 id.clone(),
441 tool_name.clone(),
442 arguments.clone(),
443 prompt.clone(),
444 session_id,
445 cx,
446 )?;
447 }
448 }
449 _ => {
450 }
452 }
453 Ok(())
454 }
455
456 async fn handle_tool_request(
457 &self,
458 tool_request: &aster::conversation::message::ToolRequest,
459 session_id: &SessionId,
460 session: &mut AsterAcpSession,
461 cx: &JrConnectionCx<AgentToClient>,
462 ) -> Result<(), sacp::Error> {
463 session
465 .tool_requests
466 .insert(tool_request.id.clone(), tool_request.clone());
467
468 let tool_name = match &tool_request.tool_call {
470 Ok(tool_call) => tool_call.name.to_string(),
471 Err(_) => "error".to_string(),
472 };
473
474 cx.send_notification(SessionNotification::new(
476 session_id.clone(),
477 SessionUpdate::ToolCall(
478 ToolCall::new(
479 ToolCallId::new(tool_request.id.clone()),
480 format_tool_name(&tool_name),
481 )
482 .status(ToolCallStatus::Pending),
483 ),
484 ))?;
485
486 Ok(())
487 }
488
489 async fn handle_tool_response(
490 &self,
491 tool_response: &aster::conversation::message::ToolResponse,
492 session_id: &SessionId,
493 session: &mut AsterAcpSession,
494 cx: &JrConnectionCx<AgentToClient>,
495 ) -> Result<(), sacp::Error> {
496 let status = if tool_response.tool_result.is_ok() {
498 ToolCallStatus::Completed
499 } else {
500 ToolCallStatus::Failed
501 };
502
503 let content = build_tool_call_content(&tool_response.tool_result);
504
505 let locations = if let Some(tool_request) = session.tool_requests.get(&tool_response.id) {
507 extract_tool_locations(tool_request, tool_response)
508 } else {
509 Vec::new()
510 };
511
512 let mut fields = ToolCallUpdateFields::new().status(status).content(content);
514 if !locations.is_empty() {
515 fields = fields.locations(locations);
516 }
517 cx.send_notification(SessionNotification::new(
518 session_id.clone(),
519 SessionUpdate::ToolCallUpdate(ToolCallUpdate::new(
520 ToolCallId::new(tool_response.id.clone()),
521 fields,
522 )),
523 ))?;
524
525 Ok(())
526 }
527
528 fn handle_tool_permission_request(
529 &self,
530 request_id: String,
531 tool_name: String,
532 arguments: serde_json::Map<String, serde_json::Value>,
533 prompt: Option<String>,
534 session_id: &SessionId,
535 cx: &JrConnectionCx<AgentToClient>,
536 ) -> Result<(), sacp::Error> {
537 let cx = cx.clone();
538 let agent = self.agent.clone();
539 let session_id = session_id.clone();
540
541 let formatted_name = format_tool_name(&tool_name);
542
543 let mut fields = ToolCallUpdateFields::new()
545 .title(formatted_name)
546 .kind(ToolKind::default())
547 .status(ToolCallStatus::Pending)
548 .raw_input(serde_json::Value::Object(arguments));
549 if let Some(p) = prompt {
550 fields = fields.content(vec![ToolCallContent::Content(Content::new(
551 ContentBlock::Text(TextContent::new(p)),
552 ))]);
553 }
554 let tool_call_update = ToolCallUpdate::new(ToolCallId::new(request_id.clone()), fields);
555
556 fn option(kind: PermissionOptionKind) -> PermissionOption {
557 let id = serde_json::to_value(kind)
558 .unwrap()
559 .as_str()
560 .unwrap()
561 .to_string();
562 PermissionOption::new(PermissionOptionId::from(id.clone()), id, kind)
563 }
564 let options = vec![
565 option(PermissionOptionKind::AllowAlways),
566 option(PermissionOptionKind::AllowOnce),
567 option(PermissionOptionKind::RejectOnce),
568 ];
569
570 let permission_request =
571 RequestPermissionRequest::new(session_id, tool_call_update, options);
572
573 cx.send_request(permission_request)
574 .on_receiving_result(move |result| async move {
575 match result {
576 Ok(response) => {
577 agent
578 .handle_confirmation(
579 request_id,
580 outcome_to_confirmation(&response.outcome),
581 )
582 .await;
583 Ok(())
584 }
585 Err(e) => {
586 error!(error = ?e, "permission request failed");
587 agent
588 .handle_confirmation(
589 request_id,
590 PermissionConfirmation {
591 principal_type: PrincipalType::Tool,
592 permission: Permission::Cancel,
593 },
594 )
595 .await;
596 Ok(())
597 }
598 }
599 })?;
600
601 Ok(())
602 }
603}
604
605fn outcome_to_confirmation(outcome: &RequestPermissionOutcome) -> PermissionConfirmation {
606 let permission = match outcome {
607 RequestPermissionOutcome::Cancelled => Permission::Cancel,
608 RequestPermissionOutcome::Selected(selected) => {
609 match serde_json::from_value::<PermissionOptionKind>(serde_json::Value::String(
610 selected.option_id.to_string(),
611 )) {
612 Ok(PermissionOptionKind::AllowAlways) => Permission::AlwaysAllow,
613 Ok(PermissionOptionKind::AllowOnce) => Permission::AllowOnce,
614 Ok(PermissionOptionKind::RejectOnce | PermissionOptionKind::RejectAlways) => {
615 Permission::DenyOnce
616 }
617 Ok(_) => Permission::Cancel, Err(_) => Permission::Cancel,
619 }
620 }
621 _ => Permission::Cancel, };
623 PermissionConfirmation {
624 principal_type: PrincipalType::Tool,
625 permission,
626 }
627}
628
629fn build_tool_call_content(tool_result: &ToolResult<CallToolResult>) -> Vec<ToolCallContent> {
630 match tool_result {
631 Ok(result) => result
632 .content
633 .iter()
634 .filter_map(|content| match &content.raw {
635 RawContent::Text(val) => Some(ToolCallContent::Content(Content::new(
636 ContentBlock::Text(TextContent::new(&val.text)),
637 ))),
638 RawContent::Image(val) => Some(ToolCallContent::Content(Content::new(
639 ContentBlock::Image(ImageContent::new(&val.data, &val.mime_type)),
640 ))),
641 RawContent::Resource(val) => {
642 let resource = match &val.resource {
643 ResourceContents::TextResourceContents {
644 mime_type,
645 text,
646 uri,
647 ..
648 } => {
649 let mut r = TextResourceContents::new(text.clone(), uri.clone());
650 if let Some(mt) = mime_type {
651 r = r.mime_type(mt.clone());
652 }
653 EmbeddedResourceResource::TextResourceContents(r)
654 }
655 ResourceContents::BlobResourceContents {
656 mime_type,
657 blob,
658 uri,
659 ..
660 } => {
661 let mut r = BlobResourceContents::new(blob.clone(), uri.clone());
662 if let Some(mt) = mime_type {
663 r = r.mime_type(mt.clone());
664 }
665 EmbeddedResourceResource::BlobResourceContents(r)
666 }
667 };
668 Some(ToolCallContent::Content(Content::new(
669 ContentBlock::Resource(EmbeddedResource::new(resource)),
670 )))
671 }
672 RawContent::Audio(_) => {
673 None
675 }
676 RawContent::ResourceLink(_) => {
677 None
679 }
680 })
681 .collect(),
682 Err(_) => Vec::new(),
683 }
684}
685
686impl AsterAcpAgent {
687 async fn on_initialize(
688 &self,
689 args: InitializeRequest,
690 ) -> Result<InitializeResponse, sacp::Error> {
691 debug!(?args, "initialize request");
692
693 let capabilities = AgentCapabilities::new()
695 .load_session(true)
696 .prompt_capabilities(
697 PromptCapabilities::new()
698 .image(true)
699 .audio(false)
700 .embedded_context(true),
701 )
702 .mcp_capabilities(McpCapabilities::new().http(true));
703 Ok(InitializeResponse::new(args.protocol_version).agent_capabilities(capabilities))
704 }
705
706 async fn on_new_session(
707 &self,
708 args: NewSessionRequest,
709 ) -> Result<NewSessionResponse, sacp::Error> {
710 debug!(?args, "new session request");
711
712 let aster_session = SessionManager::create_session(
713 std::env::current_dir().unwrap_or_default(),
714 "ACP Session".to_string(), SessionType::User,
716 )
717 .await
718 .map_err(|e| {
719 sacp::Error::new(
720 sacp::ErrorCode::InternalError.into(),
721 format!("Failed to create session: {}", e),
722 )
723 })?;
724
725 let session = AsterAcpSession {
726 messages: Conversation::new_unvalidated(Vec::new()),
727 tool_requests: HashMap::new(),
728 cancel_token: None,
729 };
730
731 let mut sessions = self.sessions.lock().await;
732 sessions.insert(aster_session.id.clone(), session);
733
734 for mcp_server in args.mcp_servers {
736 let config = match mcp_server_to_extension_config(mcp_server) {
737 Ok(c) => c,
738 Err(msg) => {
739 return Err(sacp::Error::new(sacp::ErrorCode::InvalidParams.into(), msg));
740 }
741 };
742 let name = config.name().to_string();
743 if let Err(e) = self.agent.add_extension(config).await {
744 return Err(sacp::Error::new(
745 sacp::ErrorCode::InternalError.into(),
746 format!("Failed to add MCP server '{}': {}", name, e),
747 ));
748 }
749 }
750
751 info!(
752 session_id = %aster_session.id,
753 session_type = "acp",
754 "Session started"
755 );
756
757 Ok(NewSessionResponse::new(SessionId::new(aster_session.id)))
758 }
759
760 async fn on_load_session(
761 &self,
762 args: LoadSessionRequest,
763 cx: &JrConnectionCx<AgentToClient>,
764 ) -> Result<LoadSessionResponse, sacp::Error> {
765 debug!(?args, "load session request");
766
767 let session_id = args.session_id.0.to_string();
768
769 let aster_session = SessionManager::get_session(&session_id, true)
770 .await
771 .map_err(|e| {
772 sacp::Error::new(
773 sacp::ErrorCode::InvalidParams.into(),
774 format!("Failed to load session {}: {}", session_id, e),
775 )
776 })?;
777
778 let conversation = aster_session.conversation.ok_or_else(|| {
779 sacp::Error::new(
780 sacp::ErrorCode::InternalError.into(),
781 format!("Session {} has no conversation data", session_id),
782 )
783 })?;
784
785 SessionManager::update_session(&session_id)
786 .working_dir(args.cwd.clone())
787 .apply()
788 .await
789 .map_err(|e| {
790 sacp::Error::new(
791 sacp::ErrorCode::InternalError.into(),
792 format!("Failed to update session working directory: {}", e),
793 )
794 })?;
795
796 let mut session = AsterAcpSession {
797 messages: conversation.clone(),
798 tool_requests: HashMap::new(),
799 cancel_token: None,
800 };
801
802 for message in conversation.messages() {
804 if !message.metadata.user_visible {
806 continue;
807 }
808
809 for content_item in &message.content {
810 match content_item {
811 MessageContent::Text(text) => {
812 let chunk =
813 ContentChunk::new(ContentBlock::Text(TextContent::new(&text.text)));
814 let update = match message.role {
815 Role::User => SessionUpdate::UserMessageChunk(chunk),
816 Role::Assistant => SessionUpdate::AgentMessageChunk(chunk),
817 };
818 cx.send_notification(SessionNotification::new(
819 args.session_id.clone(),
820 update,
821 ))?;
822 }
823 MessageContent::ToolRequest(tool_request) => {
824 self.handle_tool_request(tool_request, &args.session_id, &mut session, cx)
825 .await?;
826 }
827 MessageContent::ToolResponse(tool_response) => {
828 self.handle_tool_response(
829 tool_response,
830 &args.session_id,
831 &mut session,
832 cx,
833 )
834 .await?;
835 }
836 MessageContent::Thinking(thinking) => {
837 cx.send_notification(SessionNotification::new(
838 args.session_id.clone(),
839 SessionUpdate::AgentThoughtChunk(ContentChunk::new(
840 ContentBlock::Text(TextContent::new(&thinking.thinking)),
841 )),
842 ))?;
843 }
844 _ => {
845 }
847 }
848 }
849 }
850
851 let mut sessions = self.sessions.lock().await;
852 sessions.insert(session_id.clone(), session);
853
854 info!(
855 session_id = %session_id,
856 session_type = "acp",
857 "Session loaded"
858 );
859
860 Ok(LoadSessionResponse::new())
861 }
862
863 async fn on_prompt(
864 &self,
865 args: PromptRequest,
866 cx: &JrConnectionCx<AgentToClient>,
867 ) -> Result<PromptResponse, sacp::Error> {
868 let session_id = args.session_id.0.to_string();
869 let cancel_token = CancellationToken::new();
870
871 {
872 let mut sessions = self.sessions.lock().await;
873 let session = sessions.get_mut(&session_id).ok_or_else(|| {
874 sacp::Error::new(
875 sacp::ErrorCode::InvalidParams.into(),
876 format!("Session not found: {}", session_id),
877 )
878 })?;
879 session.cancel_token = Some(cancel_token.clone());
880 }
881
882 let user_message = self.convert_acp_prompt_to_message(args.prompt);
883
884 let session_config = SessionConfig {
885 id: session_id.clone(),
886 schedule_id: None,
887 max_turns: None,
888 retry_config: None,
889 system_prompt: None,
890 };
891
892 let mut stream = self
893 .agent
894 .reply(user_message, session_config, Some(cancel_token.clone()))
895 .await
896 .map_err(|e| {
897 sacp::Error::new(
898 sacp::ErrorCode::InternalError.into(),
899 format!("Error getting agent reply: {}", e),
900 )
901 })?;
902
903 use futures::StreamExt;
904
905 let mut was_cancelled = false;
906
907 while let Some(event) = stream.next().await {
908 if cancel_token.is_cancelled() {
909 was_cancelled = true;
910 break;
911 }
912
913 match event {
914 Ok(aster::agents::AgentEvent::Message(message)) => {
915 let mut sessions = self.sessions.lock().await;
916 let session = sessions.get_mut(&session_id).ok_or_else(|| {
917 sacp::Error::new(
918 sacp::ErrorCode::InvalidParams.into(),
919 format!("Session not found: {}", session_id),
920 )
921 })?;
922
923 session.messages.push(message.clone());
924
925 for content_item in &message.content {
926 self.handle_message_content(content_item, &args.session_id, session, cx)
927 .await?;
928 }
929 }
930 Ok(_) => {}
931 Err(e) => {
932 return Err(sacp::Error::new(
933 sacp::ErrorCode::InternalError.into(),
934 format!("Error in agent response stream: {}", e),
935 ));
936 }
937 }
938 }
939
940 let mut sessions = self.sessions.lock().await;
941 if let Some(session) = sessions.get_mut(&session_id) {
942 session.cancel_token = None;
943 }
944
945 let stop_reason = if was_cancelled {
946 StopReason::Cancelled
947 } else {
948 StopReason::EndTurn
949 };
950 Ok(PromptResponse::new(stop_reason))
951 }
952
953 async fn on_cancel(&self, args: CancelNotification) -> Result<(), sacp::Error> {
954 debug!(?args, "cancel request");
955
956 let session_id = args.session_id.0.to_string();
957 let mut sessions = self.sessions.lock().await;
958
959 if let Some(session) = sessions.get_mut(&session_id) {
960 if let Some(ref token) = session.cancel_token {
961 info!(session_id = %session_id, "prompt cancelled");
962 token.cancel();
963 }
964 } else {
965 warn!(session_id = %session_id, "cancel request for unknown session");
966 }
967
968 Ok(())
969 }
970}
971
972struct AsterAcpHandler {
973 agent: Arc<AsterAcpAgent>,
974}
975
976impl JrMessageHandler for AsterAcpHandler {
977 type Link = AgentToClient;
978
979 fn describe_chain(&self) -> impl std::fmt::Debug {
980 "aster-acp"
981 }
982
983 async fn handle_message(
984 &mut self,
985 message: MessageCx,
986 cx: JrConnectionCx<AgentToClient>,
987 ) -> Result<Handled<MessageCx>, sacp::Error> {
988 use sacp::util::MatchMessageFrom;
989 use sacp::JrRequestCx;
990
991 MatchMessageFrom::new(message, &cx)
992 .if_request(
993 |req: InitializeRequest, req_cx: JrRequestCx<InitializeResponse>| async {
994 req_cx.respond(self.agent.on_initialize(req).await?)
995 },
996 )
997 .await
998 .if_request(
999 |_req: AuthenticateRequest, req_cx: JrRequestCx<AuthenticateResponse>| async {
1000 req_cx.respond(AuthenticateResponse::new())
1001 },
1002 )
1003 .await
1004 .if_request(
1005 |req: NewSessionRequest, req_cx: JrRequestCx<NewSessionResponse>| async {
1006 req_cx.respond(self.agent.on_new_session(req).await?)
1007 },
1008 )
1009 .await
1010 .if_request(
1011 |req: LoadSessionRequest, req_cx: JrRequestCx<LoadSessionResponse>| async {
1012 req_cx.respond(self.agent.on_load_session(req, &cx).await?)
1013 },
1014 )
1015 .await
1016 .if_request(
1017 |req: PromptRequest, req_cx: JrRequestCx<PromptResponse>| async {
1018 let agent = self.agent.clone();
1021 let cx_clone = cx.clone();
1022 cx.spawn(async move {
1023 match agent.on_prompt(req, &cx_clone).await {
1024 Ok(response) => {
1025 req_cx.respond(response)?;
1026 }
1027 Err(e) => {
1028 req_cx.respond_with_error(e)?;
1029 }
1030 }
1031 Ok(())
1032 })?;
1033 Ok(())
1034 },
1035 )
1036 .await
1037 .if_notification(|notif: CancelNotification| async {
1038 self.agent.on_cancel(notif).await
1039 })
1040 .await
1041 .done()
1042 }
1043}
1044
1045pub async fn run_acp_agent(builtins: Vec<String>) -> Result<()> {
1046 info!("listening on stdio");
1047
1048 let outgoing = tokio::io::stdout().compat_write();
1049 let incoming = tokio::io::stdin().compat();
1050
1051 let agent = Arc::new(AsterAcpAgent::new(builtins).await?);
1052 let handler = AsterAcpHandler { agent };
1053
1054 AgentToClient::builder()
1055 .name("aster-acp")
1056 .with_handler(handler)
1057 .serve(ByteStreams::new(outgoing, incoming))
1058 .await?;
1059
1060 Ok(())
1061}
1062
#[cfg(test)]
mod tests {
    use super::*;
    use sacp::schema::{
        EnvVariable, HttpHeader, McpServer, McpServerHttp, McpServerSse, McpServerStdio,
        ResourceLink, SelectedPermissionOutcome,
    };
    use std::io::Write;
    use tempfile::NamedTempFile;
    use test_case::test_case;

    use crate::commands::acp::{
        format_tool_name, mcp_server_to_extension_config, read_resource_link,
    };
    use aster::agents::ExtensionConfig;

    #[test_case(
        McpServer::Stdio(
            McpServerStdio::new("github", "/path/to/github-mcp-server")
                .args(vec!["stdio".into()])
                .env(vec![EnvVariable::new(
                    "GITHUB_PERSONAL_ACCESS_TOKEN",
                    "ghp_xxxxxxxxxxxx"
                )])
        ),
        Ok(ExtensionConfig::Stdio {
            name: "github".into(),
            description: String::new(),
            cmd: "/path/to/github-mcp-server".into(),
            args: vec!["stdio".into()],
            envs: Envs::new(
                [(
                    "GITHUB_PERSONAL_ACCESS_TOKEN".into(),
                    "ghp_xxxxxxxxxxxx".into()
                )]
                .into()
            ),
            env_keys: vec![],
            timeout: None,
            bundled: Some(false),
            available_tools: vec![],
        })
    )]
    #[test_case(
        McpServer::Http(
            McpServerHttp::new("github", "https://api.githubcopilot.com/mcp/")
                .headers(vec![HttpHeader::new("Authorization", "Bearer ghp_xxxxxxxxxxxx")])
        ),
        Ok(ExtensionConfig::StreamableHttp {
            name: "github".into(),
            description: String::new(),
            uri: "https://api.githubcopilot.com/mcp/".into(),
            envs: Envs::default(),
            env_keys: vec![],
            headers: HashMap::from([(
                "Authorization".into(),
                "Bearer ghp_xxxxxxxxxxxx".into()
            )]),
            timeout: None,
            bundled: Some(false),
            available_tools: vec![],
        })
    )]
    #[test_case(
        McpServer::Sse(McpServerSse::new("test-sse", "https://agent-fin.biodnd.com/sse")),
        Err("SSE is unsupported, migrate to streamable_http".to_string())
    )]
    fn test_mcp_server_to_extension_config(
        input: McpServer,
        expected: Result<ExtensionConfig, String>,
    ) {
        assert_eq!(mcp_server_to_extension_config(input), expected);
    }

    /// Writes `content` to a temp file and returns a `file://` link to it.
    ///
    /// The temp file is returned alongside the link so it stays alive (and on
    /// disk) for as long as the caller holds it.
    fn new_resource_link(content: &str) -> anyhow::Result<(ResourceLink, NamedTempFile)> {
        let mut file = NamedTempFile::new()?;
        file.write_all(content.as_bytes())?;

        let name = file
            .path()
            .file_name()
            .unwrap()
            .to_string_lossy()
            .to_string();
        // Build the URI through `Url` so paths that need escaping (spaces,
        // `#`, platform-specific separators) still round-trip.
        let uri = Url::from_file_path(file.path())
            .map_err(|()| anyhow::anyhow!("temp file path is not absolute"))?
            .to_string();
        let link = ResourceLink::new(name, uri);
        Ok((link, file))
    }

    #[test]
    fn test_read_resource_link_file_scheme() {
        let (link, file) = new_resource_link("print(\"hello, world\")").unwrap();

        let result = read_resource_link(link).unwrap();
        let expected = format!(
            "\n\n# {}\n```\nprint(\"hello, world\")\n```",
            file.path().to_str().unwrap(),
        );

        assert_eq!(result, expected)
    }

    #[test]
    fn test_read_resource_link_non_file_scheme() {
        let link = ResourceLink::new(
            "remote".to_string(),
            "https://example.com/file.txt".to_string(),
        );

        assert!(read_resource_link(link).is_none());
    }

    #[test]
    fn test_extract_view_line_range() {
        assert_eq!(
            extract_view_line_range("src/main.rs (lines 10-20)"),
            Some((10, 20))
        );
        // An open-ended range reports the start as its end.
        assert_eq!(
            extract_view_line_range("src/main.rs (lines 5-end)"),
            Some((5, 5))
        );
        assert_eq!(extract_view_line_range("no range here"), None);
    }

    #[test]
    fn test_extract_first_line_number() {
        assert_eq!(
            extract_first_line_number("```rust\n42: let x = 1;\n```"),
            Some(42)
        );
        assert_eq!(extract_first_line_number("plain text"), None);
    }

    #[test]
    fn test_format_tool_name_with_extension() {
        assert_eq!(
            format_tool_name("developer__text_editor"),
            "Developer: Text Editor"
        );
        assert_eq!(
            format_tool_name("platform__manage_extensions"),
            "Platform: Manage Extensions"
        );
        assert_eq!(format_tool_name("todo__write"), "Todo: Write");
    }

    #[test]
    fn test_format_tool_name_without_extension() {
        assert_eq!(format_tool_name("simple_tool"), "Simple Tool");
        assert_eq!(format_tool_name("another_name"), "Another Name");
        assert_eq!(format_tool_name("single"), "Single");
    }

    #[test]
    fn test_format_tool_name_edge_cases() {
        assert_eq!(format_tool_name(""), "");
        assert_eq!(format_tool_name("__"), ": ");
        assert_eq!(format_tool_name("extension__"), "Extension: ");
        assert_eq!(format_tool_name("__tool"), ": Tool");
    }

    #[test_case(
        RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new("allow_once")),
        PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::AllowOnce };
        "allow_once_maps_to_allow_once"
    )]
    #[test_case(
        RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new("allow_always")),
        PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::AlwaysAllow };
        "allow_always_maps_to_always_allow"
    )]
    #[test_case(
        RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new("reject_once")),
        PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::DenyOnce };
        "reject_once_maps_to_deny_once"
    )]
    #[test_case(
        RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new("reject_always")),
        PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::DenyOnce };
        "reject_always_maps_to_deny_once"
    )]
    #[test_case(
        RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new("unknown")),
        PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::Cancel };
        "unknown_option_maps_to_cancel"
    )]
    #[test_case(
        RequestPermissionOutcome::Cancelled,
        PermissionConfirmation { principal_type: PrincipalType::Tool, permission: Permission::Cancel };
        "cancelled_maps_to_cancel"
    )]
    fn test_outcome_to_confirmation(
        input: RequestPermissionOutcome,
        expected: PermissionConfirmation,
    ) {
        assert_eq!(outcome_to_confirmation(&input), expected);
    }
}